[Mlir-commits] [mlir] 90a1632 - [mlir][spirv] Switch to kEmitAccessorPrefix_Prefixed
Jakub Kuderski
llvmlistbot at llvm.org
Fri Sep 23 21:38:24 PDT 2022
Author: Jakub Kuderski
Date: 2022-09-24T00:37:06-04:00
New Revision: 90a1632d0b1b1310bb24c4f5101727f310a94aad
URL: https://github.com/llvm/llvm-project/commit/90a1632d0b1b1310bb24c4f5101727f310a94aad
DIFF: https://github.com/llvm/llvm-project/commit/90a1632d0b1b1310bb24c4f5101727f310a94aad.diff
LOG: [mlir][spirv] Switch to kEmitAccessorPrefix_Prefixed
Fixes https://github.com/llvm/llvm-project/issues/57887
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D134580
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1f5dc076ca132..14559bd40baba 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -72,10 +72,6 @@ def SPIRV_Dialect : Dialect {
void printAttribute(
Attribute attr, DialectAsmPrinter &printer) const override;
}];
-
- // TODO(https://github.com/llvm/llvm-project/issues/57887): Switch to
- // _Prefixed accessors.
- let emitAccessorPrefix = kEmitAccessorPrefix_Both;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index 26fadb1c206bc..17ac2ed430f26 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -65,7 +65,7 @@ def SPV_BranchOp : SPV_Op<"Branch", [
let extraClassDeclaration = [{
/// Returns the block arguments.
- operand_range getBlockArguments() { return targetOperands(); }
+ operand_range getBlockArguments() { return getTargetOperands(); }
}];
let autogenSerialization = 0;
@@ -161,22 +161,22 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [
/// Returns the number of arguments to the true target block.
unsigned getNumTrueBlockArguments() {
- return trueTargetOperands().size();
+ return getTrueTargetOperands().size();
}
/// Returns the number of arguments to the false target block.
unsigned getNumFalseBlockArguments() {
- return falseTargetOperands().size();
+ return getFalseTargetOperands().size();
}
// Iterator and range support for true target block arguments.
operand_range getTrueBlockArguments() {
- return trueTargetOperands();
+ return getTrueTargetOperands();
}
// Iterator and range support for false target block arguments.
operand_range getFalseBlockArguments() {
- return falseTargetOperands();
+ return getFalseTargetOperands();
}
private:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index ab14a70ef8227..6688f059119dd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -394,9 +394,9 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
CArg<"FlatSymbolRefAttr", "nullptr">:$initializer),
[{
$_state.addAttribute("type", type);
- $_state.addAttribute(sym_nameAttrName($_state.name), sym_name);
+ $_state.addAttribute(getSymNameAttrName($_state.name), sym_name);
if (initializer)
- $_state.addAttribute(initializerAttrName($_state.name), initializer);
+ $_state.addAttribute(getInitializerAttrName($_state.name), initializer);
}]>,
OpBuilder<(ins "TypeAttr":$type, "ArrayRef<NamedAttribute>":$namedAttrs),
[{
@@ -412,9 +412,9 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
CArg<"FlatSymbolRefAttr", "{}">:$initializer),
[{
$_state.addAttribute("type", TypeAttr::get(type));
- $_state.addAttribute(sym_nameAttrName($_state.name), $_builder.getStringAttr(sym_name));
+ $_state.addAttribute(getSymNameAttrName($_state.name), $_builder.getStringAttr(sym_name));
if (initializer)
- $_state.addAttribute(initializerAttrName($_state.name), initializer);
+ $_state.addAttribute(getInitializerAttrName($_state.name), initializer);
}]>
];
@@ -424,7 +424,7 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
let extraClassDeclaration = [{
::mlir::spirv::StorageClass storageClass() {
- return this->type().cast<::mlir::spirv::PointerType>().getStorageClass();
+ return this->getType().cast<::mlir::spirv::PointerType>().getStorageClass();
}
}];
}
@@ -509,7 +509,7 @@ def SPV_ModuleOp : SPV_Op<"module",
bool isOptionalSymbol() { return true; }
- Optional<StringRef> getName() { return sym_name(); }
+ Optional<StringRef> getName() { return getSymName(); }
static StringRef getVCETripleAttrName() { return "vce_triple"; }
}];
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 766d42bd7ba95..4c0736d121d45 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -69,12 +69,12 @@ static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
builder.getIntegerAttr(targetType, targetBits / sourceBits);
auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
auto lastDim = op->getOperand(op.getNumOperands() - 1);
- auto indices = llvm::to_vector<4>(op.indices());
+ auto indices = llvm::to_vector<4>(op.getIndices());
// There are two elements if this is a 1-D tensor.
assert(indices.size() == 2);
indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
- Type t = typeConverter.convertType(op.component_ptr().getType());
- return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
+ Type t = typeConverter.convertType(op.getComponentPtr().getType());
+ return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}
/// Returns the shifted `targetBits`-bit value with the given offset.
@@ -371,7 +371,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// Assume that getElementPtr() works linearizely. If it's a scalar, the method
// still returns a linearized accessing. If the accessing is not linearized,
// there will be offset issues.
- assert(accessChainOp.indices().size() == 2);
+ assert(accessChainOp.getIndices().size() == 2);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
Value spvLoadOp = rewriter.create<spirv::LoadOp>(
@@ -507,7 +507,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
// 6) store 32-bit value back
// The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
// 4 to step 6 are done by AtomicOr as another atomic step.
- assert(accessChainOp.indices().size() == 2);
+ assert(accessChainOp.getIndices().size() == 2);
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 10623b63d949d..72ccb3a825183 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -174,7 +174,7 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
// Create the block for the header.
auto *header = new Block();
// Insert the header.
- loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header);
+ loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1), header);
// Create the new induction variable to use.
Value adapLowerBound = adaptor.getLowerBound();
@@ -197,13 +197,13 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
// Move the blocks from the forOp into the loopOp. This is the body of the
// loopOp.
- rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
- getBlockIt(loopOp.body(), 2));
+ rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
+ getBlockIt(loopOp.getBody(), 2));
SmallVector<Value, 8> args(1, adaptor.getLowerBound());
args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
// Branch into it from the entry.
- rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
+ rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
rewriter.create<spirv::BranchOp>(loc, header, args);
// Generate the rest of the loop header.
@@ -252,12 +252,12 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
auto selectionOp =
rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
auto *mergeBlock =
- rewriter.createBlock(&selectionOp.body(), selectionOp.body().end());
+ rewriter.createBlock(&selectionOp.getBody(), selectionOp.getBody().end());
rewriter.create<spirv::MergeOp>(loc);
OpBuilder::InsertionGuard guard(rewriter);
auto *selectionHeaderBlock =
- rewriter.createBlock(&selectionOp.body().front());
+ rewriter.createBlock(&selectionOp.getBody().front());
// Inline `then` region before the merge block and branch to it.
auto &thenRegion = ifOp.getThenRegion();
@@ -367,12 +367,12 @@ WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
return failure();
// Move the while before block as the initial loop header block.
- rewriter.inlineRegionBefore(beforeRegion, loopOp.body(),
- getBlockIt(loopOp.body(), 1));
+ rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
+ getBlockIt(loopOp.getBody(), 1));
// Move the while after block as the initial loop body block.
- rewriter.inlineRegionBefore(afterRegion, loopOp.body(),
- getBlockIt(loopOp.body(), 2));
+ rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
+ getBlockIt(loopOp.getBody(), 2));
// Jump from the loop entry block to the loop header block.
rewriter.setInsertionPointToEnd(&entryBlock);
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index c09ad0eb61729..e97774832f568 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -89,7 +89,7 @@ createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
op->getAttrOfType<IntegerAttr>(descriptorSetName());
IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
- kernelModuleName.str(), op.sym_name().str(),
+ kernelModuleName.str(), op.getSymName().str(),
std::to_string(descriptorSet.getInt()),
std::to_string(binding.getInt()));
}
@@ -126,14 +126,14 @@ static LogicalResult getKernelGlobalVariables(
/// Encodes the SPIR-V module's symbolic name into the name of the entry point
/// function.
static LogicalResult encodeKernelName(spirv::ModuleOp module) {
- StringRef spvModuleName = *module.sym_name();
+ StringRef spvModuleName = *module.getSymName();
// We already know that the module contains exactly one entry point function
// based on `getKernelGlobalVariables()` call. Update this function's name
// to:
// {spv_module_name}_{function_name}
auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
- StringRef funcName = entryPoint.fn();
- auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr());
+ StringRef funcName = entryPoint.getFn();
+ auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
StringAttr newFuncName =
StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
@@ -236,7 +236,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
// LLVM dialect global variable.
spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
auto pointeeType =
- spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
+ spirvGlobal.getType().cast<spirv::PointerType>().getPointeeType();
auto dstGlobalType = typeConverter->convertType(pointeeType);
if (!dstGlobalType)
return failure();
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 1d6ae56a6c96e..7a3c873426435 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -228,14 +228,14 @@ static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
if (!dstType)
return failure();
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
- loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment,
+ loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
isVolatile, isNonTemporal);
return success();
}
auto storeOp = cast<spirv::StoreOp>(op);
spirv::StoreOpAdaptor adaptor(operands);
- rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
- adaptor.ptr(), alignment,
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
+ adaptor.getPtr(), alignment,
isVolatile, isNonTemporal);
return success();
}
@@ -305,19 +305,19 @@ class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
LogicalResult
matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = typeConverter.convertType(op.component_ptr().getType());
+ auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
if (!dstType)
return failure();
// To use GEP we need to add a first 0 index to go through the pointer.
- auto indices = llvm::to_vector<4>(adaptor.indices());
- Type indexType = op.indices().front().getType();
+ auto indices = llvm::to_vector<4>(adaptor.getIndices());
+ Type indexType = op.getIndices().front().getType();
auto llvmIndexType = typeConverter.convertType(indexType);
if (!llvmIndexType)
return failure();
Value zero = rewriter.create<LLVM::ConstantOp>(
op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
indices.insert(indices.begin(), zero);
- rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.base_ptr(),
+ rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.getBasePtr(),
indices);
return success();
}
@@ -330,10 +330,10 @@ class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
LogicalResult
matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto dstType = typeConverter.convertType(op.pointer().getType());
+ auto dstType = typeConverter.convertType(op.getPointer().getType());
if (!dstType)
return failure();
- rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
+ rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.getVariable());
return success();
}
};
@@ -353,9 +353,9 @@ class BitFieldInsertPattern
Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
- Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
+ Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
typeConverter, rewriter);
- Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
+ Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
typeConverter, rewriter);
// Create a mask with bits set outside [Offset, Offset + Count - 1].
@@ -372,9 +372,9 @@ class BitFieldInsertPattern
// Extract unchanged bits from the `Base` that are outside of
// [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
Value baseAndMask =
- rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
+ rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
Value insertShiftedByOffset =
- rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
+ rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
insertShiftedByOffset);
return success();
@@ -408,14 +408,14 @@ class ConstantScalarAndVectorPattern
auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
if (srcType.isa<VectorType>()) {
- auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
+ auto dstElementsAttr = constOp.getValue().cast<DenseIntElementsAttr>();
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
constOp, dstType,
dstElementsAttr.mapValues(
signlessType, [&](const APInt &value) { return value; }));
return success();
}
- auto srcAttr = constOp.value().cast<IntegerAttr>();
+ auto srcAttr = constOp.getValue().cast<IntegerAttr>();
auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
return success();
@@ -441,9 +441,9 @@ class BitFieldSExtractPattern
Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
- Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
+ Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
typeConverter, rewriter);
- Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
+ Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
typeConverter, rewriter);
// Create a constant that holds the size of the `Base`.
@@ -468,7 +468,7 @@ class BitFieldSExtractPattern
Value amountToShiftLeft =
rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
- loc, dstType, op.base(), amountToShiftLeft);
+ loc, dstType, op.getBase(), amountToShiftLeft);
// Shift the result right, filling the bits with the sign bit.
Value amountToShiftRight =
@@ -494,9 +494,9 @@ class BitFieldUExtractPattern
Location loc = op.getLoc();
// Process `Offset` and `Count`: broadcast and extend/truncate if needed.
- Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
+ Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
typeConverter, rewriter);
- Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
+ Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
typeConverter, rewriter);
// Create a mask with bits set at [0, Count - 1].
@@ -508,7 +508,7 @@ class BitFieldUExtractPattern
// Shift `Base` by `Offset` and apply the mask on it.
Value shiftedBase =
- rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
+ rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
return success();
}
@@ -538,20 +538,20 @@ class BranchConditionalConversionPattern
ConversionPatternRewriter &rewriter) const override {
// If branch weights exist, map them to 32-bit integer vector.
ElementsAttr branchWeights = nullptr;
- if (auto weights = op.branch_weights()) {
+ if (auto weights = op.getBranchWeights()) {
VectorType weightType = VectorType::get(2, rewriter.getI32Type());
branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
}
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
- op, op.condition(), op.getTrueBlockArguments(),
+ op, op.getCondition(), op.getTrueBlockArguments(),
op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
op.getFalseBlock());
return success();
}
};
-/// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
+/// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
/// is an aggregate type (struct or array). Otherwise, converts to
/// `llvm.extractelement` that operates on vectors.
class CompositeExtractPattern
@@ -566,23 +566,23 @@ class CompositeExtractPattern
if (!dstType)
return failure();
- Type containerType = op.composite().getType();
+ Type containerType = op.getComposite().getType();
if (containerType.isa<VectorType>()) {
Location loc = op.getLoc();
- IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
+ IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- op, dstType, adaptor.composite(), index);
+ op, dstType, adaptor.getComposite(), index);
return success();
}
rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
- op, adaptor.composite(), LLVM::convertArrayToIndices(op.indices()));
+ op, adaptor.getComposite(), LLVM::convertArrayToIndices(op.getIndices()));
return success();
}
};
-/// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
+/// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
/// is an aggregate type (struct or array). Otherwise, converts to
/// `llvm.insertelement` that operates on vectors.
class CompositeInsertPattern
@@ -597,19 +597,19 @@ class CompositeInsertPattern
if (!dstType)
return failure();
- Type containerType = op.composite().getType();
+ Type containerType = op.getComposite().getType();
if (containerType.isa<VectorType>()) {
Location loc = op.getLoc();
- IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
+ IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
Value index = createI32ConstantOf(loc, rewriter, value.getInt());
rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- op, dstType, adaptor.composite(), adaptor.object(), index);
+ op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
return success();
}
rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
- op, adaptor.composite(), adaptor.object(),
- LLVM::convertArrayToIndices(op.indices()));
+ op, adaptor.getComposite(), adaptor.getObject(),
+ LLVM::convertArrayToIndices(op.getIndices()));
return success();
}
};
@@ -647,14 +647,14 @@ class ExecutionModePattern
// this entry point's execution mode. We set it to be:
// __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
ModuleOp module = op->getParentOfType<ModuleOp>();
- spirv::ExecutionModeAttr executionModeAttr = op.execution_modeAttr();
+ spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
std::string moduleName;
if (module.getName().has_value())
- moduleName = "_" + module.getName().value().str();
+ moduleName = "_" + module.getName()->str();
else
moduleName = "";
std::string executionModeInfoName = llvm::formatv(
- "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.fn().str(),
+ "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
static_cast<uint32_t>(executionModeAttr.getValue()));
MLIRContext *context = rewriter.getContext();
@@ -669,7 +669,7 @@ class ExecutionModePattern
auto llvmI32Type = IntegerType::get(context, 32);
SmallVector<Type, 2> fields;
fields.push_back(llvmI32Type);
- ArrayAttr values = op.values();
+ ArrayAttr values = op.getValues();
if (!values.empty()) {
auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
fields.push_back(arrayType);
@@ -722,10 +722,10 @@ class GlobalVariablePattern
ConversionPatternRewriter &rewriter) const override {
// Currently, there is no support of initialization with a constant value in
// SPIR-V dialect. Specialization constants are not considered as well.
- if (op.initializer())
+ if (op.getInitializer())
return failure();
- auto srcType = op.type().cast<spirv::PointerType>();
+ auto srcType = op.getType().cast<spirv::PointerType>();
auto dstType = typeConverter.convertType(srcType.getPointeeType());
if (!dstType)
return failure();
@@ -759,12 +759,12 @@ class GlobalVariablePattern
? LLVM::Linkage::Private
: LLVM::Linkage::External;
auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
- op, dstType, isConstant, linkage, op.sym_name(), Attribute(),
+ op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
/*alignment=*/0);
// Attach location attribute if applicable
- if (op.locationAttr())
- newGlobalOp->setAttr(op.locationAttrName(), op.locationAttr());
+ if (op.getLocationAttr())
+ newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
return success();
}
@@ -781,7 +781,7 @@ class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type fromType = operation.operand().getType();
+ Type fromType = operation.getOperand().getType();
Type toType = operation.getType();
auto dstType = this->typeConverter.convertType(toType);
@@ -839,8 +839,8 @@ class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
return failure();
rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
- operation, dstType, predicate, operation.operand1(),
- operation.operand2());
+ operation, dstType, predicate, operation.getOperand1(),
+ operation.getOperand2());
return success();
}
};
@@ -860,8 +860,8 @@ class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
return failure();
rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
- operation, dstType, predicate, operation.operand1(),
- operation.operand2());
+ operation, dstType, predicate, operation.getOperand1(),
+ operation.getOperand2());
return success();
}
};
@@ -881,7 +881,7 @@ class InverseSqrtPattern
Location loc = op.getLoc();
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
- Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
+ Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
return success();
}
@@ -896,20 +896,20 @@ class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
LogicalResult
matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (!op.memory_access()) {
+ if (!op.getMemoryAccess()) {
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
this->typeConverter, /*alignment=*/0,
/*isVolatile=*/false,
/*isNonTemporal=*/false);
}
- auto memoryAccess = *op.memory_access();
+ auto memoryAccess = *op.getMemoryAccess();
switch (memoryAccess) {
case spirv::MemoryAccess::Aligned:
case spirv::MemoryAccess::None:
case spirv::MemoryAccess::Nontemporal:
case spirv::MemoryAccess::Volatile: {
unsigned alignment =
- memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
+ memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
@@ -946,7 +946,7 @@ class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
srcType.template cast<VectorType>(), minusOne))
: rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
- notOp.operand(), mask);
+ notOp.getOperand(), mask);
return success();
}
};
@@ -1047,7 +1047,7 @@ class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// There is no support of loop control at the moment.
- if (loopOp.loop_control() != spirv::LoopControl::None)
+ if (loopOp.getLoopControl() != spirv::LoopControl::None)
return failure();
Location loc = loopOp.getLoc();
@@ -1077,7 +1077,7 @@ class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
rewriter.setInsertionPointToEnd(mergeBlock);
rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
- rewriter.inlineRegionBefore(loopOp.body(), endBlock);
+ rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
rewriter.replaceOp(loopOp, endBlock->getArguments());
return success();
}
@@ -1096,14 +1096,14 @@ class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
// There is no support for `Flatten` or `DontFlatten` selection control at
// the moment. This are just compiler hints and can be performed during the
// optimization passes.
- if (op.selection_control() != spirv::SelectionControl::None)
+ if (op.getSelectionControl() != spirv::SelectionControl::None)
return failure();
// `spv.mlir.selection` should have at least two blocks: one selection
// header block and one merge block. If no blocks are present, or control
// flow branches straight to merge block (two blocks are present), the op is
// redundant and it is erased.
- if (op.body().getBlocks().size() <= 2) {
+ if (op.getBody().getBlocks().size() <= 2) {
rewriter.eraseOp(op);
return success();
}
@@ -1140,11 +1140,11 @@ class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
Block *trueBlock = condBrOp.getTrueBlock();
Block *falseBlock = condBrOp.getFalseBlock();
rewriter.setInsertionPointToEnd(currentBlock);
- rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
- condBrOp.trueTargetOperands(), falseBlock,
- condBrOp.falseTargetOperands());
+ rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
+ condBrOp.getTrueTargetOperands(), falseBlock,
+ condBrOp.getFalseTargetOperands());
- rewriter.inlineRegionBefore(op.body(), continueBlock);
+ rewriter.inlineRegionBefore(op.getBody(), continueBlock);
rewriter.replaceOp(op, continueBlock->getArguments());
return success();
}
@@ -1167,8 +1167,8 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
if (!dstType)
return failure();
- Type op1Type = operation.operand1().getType();
- Type op2Type = operation.operand2().getType();
+ Type op1Type = operation.getOperand1().getType();
+ Type op2Type = operation.getOperand2().getType();
if (op1Type == op2Type) {
rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
@@ -1180,13 +1180,13 @@ class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
Value extended;
if (isUnsignedIntegerOrVector(op2Type)) {
extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
- adaptor.operand2());
+ adaptor.getOperand2());
} else {
extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
- adaptor.operand2());
+ adaptor.getOperand2());
}
Value result = rewriter.template create<LLVMOp>(
- loc, dstType, adaptor.operand1(), extended);
+ loc, dstType, adaptor.getOperand1(), extended);
rewriter.replaceOp(operation, result);
return success();
}
@@ -1204,8 +1204,8 @@ class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
return failure();
Location loc = tanOp.getLoc();
- Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
- Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
+ Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
+ Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
return success();
}
@@ -1232,7 +1232,7 @@ class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
Location loc = tanhOp.getLoc();
Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
Value multiplied =
- rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
+ rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
Value numerator =
@@ -1255,7 +1255,7 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
auto srcType = varOp.getType();
// Initialization is supported for scalars and vectors only.
auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
- auto init = varOp.initializer();
+ auto init = varOp.getInitializer();
if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
return failure();
@@ -1270,7 +1270,7 @@ class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
return success();
}
Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
- rewriter.create<LLVM::StoreOp>(loc, adaptor.initializer(), allocated);
+ rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
rewriter.replaceOp(varOp, allocated);
return success();
}
@@ -1305,7 +1305,7 @@ class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
// Convert SPIR-V Function Control to equivalent LLVM function attribute
MLIRContext *context = funcOp.getContext();
- switch (funcOp.function_control()) {
+ switch (funcOp.getFunctionControl()) {
#define DISPATCH(functionControl, llvmAttr) \
case functionControl: \
newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
@@ -1374,9 +1374,9 @@ class VectorShufflePattern
matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto components = adaptor.components();
- auto vector1 = adaptor.vector1();
- auto vector2 = adaptor.vector2();
+ auto components = adaptor.getComponents();
+ auto vector1 = adaptor.getVector1();
+ auto vector2 = adaptor.getVector2();
int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
if (vector1Size == vector2Size) {
@@ -1589,8 +1589,8 @@ void mlir::encodeBindAttribute(ModuleOp module) {
// SPIR-V module has a name, add it at the beginning.
auto moduleAndName =
spvModule.getName().has_value()
- ? spvModule.getName().value().str() + "_" + op.sym_name().str()
- : op.sym_name().str();
+ ? spvModule.getName()->str() + "_" + op.getSymName().str()
+ : op.getSymName().str();
std::string name =
llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
std::to_string(descriptorSet.getInt()),
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 8db35960dc668..fd197a40a7256 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -88,19 +88,19 @@ struct CombineChainedAccessChain
LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
PatternRewriter &rewriter) const override {
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
- accessChainOp.base_ptr().getDefiningOp());
+ accessChainOp.getBasePtr().getDefiningOp());
if (!parentAccessChainOp) {
return failure();
}
// Combine indices.
- SmallVector<Value, 4> indices(parentAccessChainOp.indices());
- indices.append(accessChainOp.indices().begin(),
- accessChainOp.indices().end());
+ SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
+ indices.append(accessChainOp.getIndices().begin(),
+ accessChainOp.getIndices().end());
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
- accessChainOp, parentAccessChainOp.base_ptr(), indices);
+ accessChainOp, parentAccessChainOp.getBasePtr(), indices);
return success();
}
@@ -126,23 +126,24 @@ void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
- if (auto insertOp = composite().getDefiningOp<spirv::CompositeInsertOp>()) {
- if (indices() == insertOp.indices())
- return insertOp.object();
+ if (auto insertOp =
+ getComposite().getDefiningOp<spirv::CompositeInsertOp>()) {
+ if (getIndices() == insertOp.getIndices())
+ return insertOp.getObject();
}
if (auto constructOp =
- composite().getDefiningOp<spirv::CompositeConstructOp>()) {
+ getComposite().getDefiningOp<spirv::CompositeConstructOp>()) {
auto type = constructOp.getType().cast<spirv::CompositeType>();
- if (indices().size() == 1 &&
- constructOp.constituents().size() == type.getNumElements()) {
- auto i = indices().begin()->cast<IntegerAttr>();
- return constructOp.constituents()[i.getValue().getSExtValue()];
+ if (getIndices().size() == 1 &&
+ constructOp.getConstituents().size() == type.getNumElements()) {
+ auto i = getIndices().begin()->cast<IntegerAttr>();
+ return constructOp.getConstituents()[i.getValue().getSExtValue()];
}
}
auto indexVector =
- llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
+ llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) {
return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
}));
return extractCompositeElement(operands[0], indexVector);
@@ -154,7 +155,7 @@ OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "spv.Constant has no operands");
- return value();
+ return getValue();
}
//===----------------------------------------------------------------------===//
@@ -164,8 +165,8 @@ OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "spv.IAdd expects two operands");
// x + 0 = x
- if (matchPattern(operand2(), m_Zero()))
- return operand1();
+ if (matchPattern(getOperand2(), m_Zero()))
+ return getOperand1();
// According to the SPIR-V spec:
//
@@ -183,11 +184,11 @@ OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "spv.IMul expects two operands");
// x * 0 == 0
- if (matchPattern(operand2(), m_Zero()))
- return operand2();
+ if (matchPattern(getOperand2(), m_Zero()))
+ return getOperand2();
// x * 1 = x
- if (matchPattern(operand2(), m_One()))
- return operand1();
+ if (matchPattern(getOperand2(), m_One()))
+ return getOperand1();
// According to the SPIR-V spec:
//
@@ -204,7 +205,7 @@ OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
// x - x = 0
- if (operand1() == operand2())
+ if (getOperand1() == getOperand2())
return Builder(getContext()).getIntegerAttr(getType(), 0);
// According to the SPIR-V spec:
@@ -226,7 +227,7 @@ OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
// x && true = x
if (rhs.value())
- return operand1();
+ return getOperand1();
// x && false = false
if (!rhs.value())
@@ -262,7 +263,7 @@ OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
// x || false = x
if (!rhs.value())
- return operand1();
+ return getOperand1();
}
return Attribute();
@@ -339,8 +340,8 @@ struct ConvertSelectionOpToSelect
cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
auto selectOp = rewriter.create<spirv::SelectOp>(
- selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
- trueValue, falseValue);
+ selectionOp.getLoc(), trueValue.getType(),
+ brConditionalOp.getCondition(), trueValue, falseValue);
rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
selectOp.getResult(), storeOpAttributes);
@@ -371,13 +372,13 @@ struct ConvertSelectionOpToSelect
// Returns a source value for the given block.
Value getSrcValue(Block *block) const {
auto storeOp = cast<spirv::StoreOp>(block->front());
- return storeOp.value();
+ return storeOp.getValue();
}
// Returns a destination value for the given block.
Value getDstPtr(Block *block) const {
auto storeOp = cast<spirv::StoreOp>(block->front());
- return storeOp.ptr();
+ return storeOp.getPtr();
}
};
@@ -406,14 +407,14 @@ LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
// "Before version 1.4, Result Type must be a pointer, scalar, or vector.
// Starting with version 1.4, Result Type can additionally be a composite type
// other than a vector."
- bool isScalarOrVector = trueBrStoreOp.value()
+ bool isScalarOrVector = trueBrStoreOp.getValue()
.getType()
.cast<spirv::SPIRVType>()
.isScalarOrVector();
// Check that each `spv.Store` uses the same pointer, memory access
// attributes and a valid type of the value.
- if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
+ if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
return failure();
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 6333f87804563..5703d7e6df813 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -106,7 +106,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
// Replace the values directly with the return operands.
assert(valuesToRepl.size() == 1 &&
"spv.ReturnValue expected to only handle one result");
- valuesToRepl.front().replaceAllUsesWith(retValOp.value());
+ valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
}
};
} // namespace
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 569fb61ecc14a..3582d45696023 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -136,7 +136,7 @@ static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
if (!constOp) {
return failure();
}
- auto valueAttr = constOp.value();
+ auto valueAttr = constOp.getValue();
auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
if (!integerValueAttr) {
return failure();
@@ -313,7 +313,7 @@ static void printMemoryAccessAttribute(
Optional<uint32_t> alignmentAttrValue = None) {
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
- : memoryOp.memory_access())) {
+ : memoryOp.getMemoryAccess())) {
elidedAttrs.push_back(kMemoryAccessAttrName);
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
@@ -321,7 +321,7 @@ static void printMemoryAccessAttribute(
if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
- : memoryOp.alignment())) {
+ : memoryOp.getAlignment())) {
elidedAttrs.push_back(kAlignmentAttrName);
printer << ", " << alignment;
}
@@ -346,7 +346,7 @@ static void printSourceMemoryAccessAttribute(
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
- : memoryOp.memory_access())) {
+ : memoryOp.getMemoryAccess())) {
elidedAttrs.push_back(kSourceMemoryAccessAttrName);
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
@@ -354,7 +354,7 @@ static void printSourceMemoryAccessAttribute(
if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
- : memoryOp.alignment())) {
+ : memoryOp.getAlignment())) {
elidedAttrs.push_back(kSourceAlignmentAttrName);
printer << ", " << alignment;
}
@@ -1086,17 +1086,17 @@ ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
template <typename Op>
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
- printer << ' ' << op.base_ptr() << '[' << indices
- << "] : " << op.base_ptr().getType() << ", " << indices.getTypes();
+ printer << ' ' << op.getBasePtr() << '[' << indices
+ << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
}
void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
- printAccessChain(*this, indices(), printer);
+ printAccessChain(*this, getIndices(), printer);
}
template <typename Op>
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
- auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
+ auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
indices, accessChainOp.getLoc());
if (!resultType)
return failure();
@@ -1116,7 +1116,7 @@ static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
}
LogicalResult spirv::AccessChainOp::verify() {
- return verifyAccessChain(*this, indices());
+ return verifyAccessChain(*this, getIndices());
}
//===----------------------------------------------------------------------===//
@@ -1125,17 +1125,17 @@ LogicalResult spirv::AccessChainOp::verify() {
void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
spirv::GlobalVariableOp var) {
- build(builder, state, var.type(), SymbolRefAttr::get(var));
+ build(builder, state, var.getType(), SymbolRefAttr::get(var));
}
LogicalResult spirv::AddressOfOp::verify() {
auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
- variableAttr()));
+ getVariableAttr()));
if (!varOp) {
return emitOpError("expected spv.GlobalVariable symbol");
}
- if (pointer().getType() != varOp.type()) {
+ if (getPointer().getType() != varOp.getType()) {
return emitOpError(
"result type mismatch with the referenced global variable's type");
}
@@ -1144,10 +1144,10 @@ LogicalResult spirv::AddressOfOp::verify() {
template <typename T>
static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
- printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
- << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
- << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
- << atomOp.getOperands() << " : " << atomOp.pointer().getType();
+ printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \""
+ << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \""
+ << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" "
+ << atomOp.getOperands() << " : " << atomOp.getPointer().getType();
}
static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
@@ -1188,18 +1188,18 @@ static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
// "The type of Value must be the same as Result Type. The type of the value
// pointed to by Pointer must be the same as Result Type. This type must also
// match the type of Comparator."
- if (atomOp.getType() != atomOp.value().getType())
+ if (atomOp.getType() != atomOp.getValue().getType())
return atomOp.emitOpError("value operand must have the same type as the op "
"result, but found ")
- << atomOp.value().getType() << " vs " << atomOp.getType();
+ << atomOp.getValue().getType() << " vs " << atomOp.getType();
- if (atomOp.getType() != atomOp.comparator().getType())
+ if (atomOp.getType() != atomOp.getComparator().getType())
return atomOp.emitOpError(
"comparator operand must have the same type as the op "
"result, but found ")
- << atomOp.comparator().getType() << " vs " << atomOp.getType();
+ << atomOp.getComparator().getType() << " vs " << atomOp.getType();
- Type pointeeType = atomOp.pointer()
+ Type pointeeType = atomOp.getPointer()
.getType()
.template cast<spirv::PointerType>()
.getPointeeType();
@@ -1268,9 +1268,9 @@ void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) {
- printer << " \"" << stringifyScope(memory_scope()) << "\" \""
- << stringifyMemorySemantics(semantics()) << "\" " << getOperands()
- << " : " << pointer().getType();
+ printer << " \"" << stringifyScope(getMemoryScope()) << "\" \""
+ << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands()
+ << " : " << getPointer().getType();
}
ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
@@ -1302,13 +1302,13 @@ ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
}
LogicalResult spirv::AtomicExchangeOp::verify() {
- if (getType() != value().getType())
+ if (getType() != getValue().getType())
return emitOpError("value operand must have the same type as the op "
"result, but found ")
- << value().getType() << " vs " << getType();
+ << getValue().getType() << " vs " << getType();
Type pointeeType =
- pointer().getType().cast<spirv::PointerType>().getPointeeType();
+ getPointer().getType().cast<spirv::PointerType>().getPointeeType();
if (getType() != pointeeType)
return emitOpError("pointer operand's pointee type must have the same "
"as the op result type, but found ")
@@ -1500,8 +1500,8 @@ void spirv::AtomicXorOp::print(OpAsmPrinter &p) {
LogicalResult spirv::BitcastOp::verify() {
// TODO: The SPIR-V spec validation rules are different for different
// versions.
- auto operandType = operand().getType();
- auto resultType = result().getType();
+ auto operandType = getOperand().getType();
+ auto resultType = getResult().getType();
if (operandType == resultType) {
return emitError("result type must be different from operand type");
}
@@ -1530,8 +1530,8 @@ LogicalResult spirv::BitcastOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::PtrCastToGenericOp::verify() {
- auto operandType = pointer().getType().cast<spirv::PointerType>();
- auto resultType = result().getType().cast<spirv::PointerType>();
+ auto operandType = getPointer().getType().cast<spirv::PointerType>();
+ auto resultType = getResult().getType().cast<spirv::PointerType>();
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Workgroup &&
@@ -1558,8 +1558,8 @@ LogicalResult spirv::PtrCastToGenericOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::GenericCastToPtrOp::verify() {
- auto operandType = pointer().getType().cast<spirv::PointerType>();
- auto resultType = result().getType().cast<spirv::PointerType>();
+ auto operandType = getPointer().getType().cast<spirv::PointerType>();
+ auto resultType = getResult().getType().cast<spirv::PointerType>();
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Generic)
@@ -1586,8 +1586,8 @@ LogicalResult spirv::GenericCastToPtrOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::GenericCastToPtrExplicitOp::verify() {
- auto operandType = pointer().getType().cast<spirv::PointerType>();
- auto resultType = result().getType().cast<spirv::PointerType>();
+ auto operandType = getPointer().getType().cast<spirv::PointerType>();
+ auto resultType = getResult().getType().cast<spirv::PointerType>();
spirv::StorageClass operandStorage = operandType.getStorageClass();
if (operandStorage != spirv::StorageClass::Generic)
@@ -1615,7 +1615,7 @@ LogicalResult spirv::GenericCastToPtrExplicitOp::verify() {
SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
- return SuccessorOperands(0, targetOperandsMutable());
+ return SuccessorOperands(0, getTargetOperandsMutable());
}
//===----------------------------------------------------------------------===//
@@ -1625,8 +1625,9 @@ SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
SuccessorOperands
spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
assert(index < 2 && "invalid successor index");
- return SuccessorOperands(index == kTrueIndex ? trueTargetOperandsMutable()
- : falseTargetOperandsMutable());
+ return SuccessorOperands(index == kTrueIndex
+ ? getTrueTargetOperandsMutable()
+ : getFalseTargetOperandsMutable());
}
ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
@@ -1681,9 +1682,9 @@ ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
}
void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
- printer << ' ' << condition();
+ printer << ' ' << getCondition();
- if (auto weights = branch_weights()) {
+ if (auto weights = getBranchWeights()) {
printer << " [";
llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
printer << a.cast<IntegerAttr>().getInt();
@@ -1698,7 +1699,7 @@ void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::BranchConditionalOp::verify() {
- if (auto weights = branch_weights()) {
+ if (auto weights = getBranchWeights()) {
if (weights->getValue().size() != 2) {
return emitOpError("must have exactly two branch weights");
}
@@ -1717,7 +1718,7 @@ LogicalResult spirv::BranchConditionalOp::verify() {
LogicalResult spirv::CompositeConstructOp::verify() {
auto cType = getType().cast<spirv::CompositeType>();
- operand_range constituents = this->constituents();
+ operand_range constituents = this->getConstituents();
if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
if (constituents.size() != 1)
@@ -1828,13 +1829,14 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
}
void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
- printer << ' ' << composite() << indices() << " : " << composite().getType();
+ printer << ' ' << getComposite() << getIndices() << " : "
+ << getComposite().getType();
}
LogicalResult spirv::CompositeExtractOp::verify() {
- auto indicesArrayAttr = indices().dyn_cast<ArrayAttr>();
+ auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
auto resultType =
- getElementType(composite().getType(), indicesArrayAttr, getLoc());
+ getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
if (!resultType)
return failure();
@@ -1875,29 +1877,30 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
}
LogicalResult spirv::CompositeInsertOp::verify() {
- auto indicesArrayAttr = indices().dyn_cast<ArrayAttr>();
+ auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
auto objectType =
- getElementType(composite().getType(), indicesArrayAttr, getLoc());
+ getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
if (!objectType)
return failure();
- if (objectType != object().getType()) {
+ if (objectType != getObject().getType()) {
return emitOpError("object operand type should be ")
- << objectType << ", but found " << object().getType();
+ << objectType << ", but found " << getObject().getType();
}
- if (composite().getType() != getType()) {
+ if (getComposite().getType() != getType()) {
return emitOpError("result type should be the same as "
"the composite type, but found ")
- << composite().getType() << " vs " << getType();
+ << getComposite().getType() << " vs " << getType();
}
return success();
}
void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
- printer << " " << object() << ", " << composite() << indices() << " : "
- << object().getType() << " into " << composite().getType();
+ printer << " " << getObject() << ", " << getComposite() << getIndices()
+ << " : " << getObject().getType() << " into "
+ << getComposite().getType();
}
//===----------------------------------------------------------------------===//
@@ -1922,7 +1925,7 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
}
void spirv::ConstantOp::print(OpAsmPrinter &printer) {
- printer << ' ' << value();
+ printer << ' ' << getValue();
if (getType().isa<spirv::ArrayType>())
printer << " : " << getType();
}
@@ -1989,7 +1992,7 @@ LogicalResult spirv::ConstantOp::verify() {
// ODS already generates checks to make sure the result type is valid. We just
// need to additionally check that the value's attribute type is consistent
// with the result type.
- return verifyConstantType(*this, valueAttr(), getType());
+ return verifyConstantType(*this, getValueAttr(), getType());
}
bool spirv::ConstantOp::isBuildableWith(Type type) {
@@ -2081,7 +2084,7 @@ void mlir::spirv::ConstantOp::getAsmResultNames(
IntegerType intTy = type.dyn_cast<IntegerType>();
- if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
+ if (IntegerAttr intCst = getValue().dyn_cast<IntegerAttr>()) {
if (intTy && intTy.getWidth() == 1) {
return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
}
@@ -2115,7 +2118,7 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
- specialName << variable() << "_addr";
+ specialName << getVariable() << "_addr";
setNameFn(getResult(), specialName.str());
}
@@ -2124,7 +2127,7 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
//===----------------------------------------------------------------------===//
LogicalResult spirv::ControlBarrierOp::verify() {
- return verifyMemorySemantics(getOperation(), memory_semantics());
+ return verifyMemorySemantics(getOperation(), getMemorySemantics());
}
//===----------------------------------------------------------------------===//
@@ -2208,9 +2211,9 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
}
void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
- printer << " \"" << stringifyExecutionModel(execution_model()) << "\" ";
- printer.printSymbolName(fn());
- auto interfaceVars = interface().getValue();
+ printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
+ printer.printSymbolName(getFn());
+ auto interfaceVars = getInterface().getValue();
if (!interfaceVars.empty()) {
printer << ", ";
llvm::interleaveComma(interfaceVars, printer);
@@ -2262,9 +2265,9 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
printer << " ";
- printer.printSymbolName(fn());
- printer << " \"" << stringifyExecutionMode(execution_mode()) << "\"";
- auto values = this->values();
+ printer.printSymbolName(getFn());
+ printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
+ auto values = this->getValues();
if (values.empty())
return;
printer << ", ";
@@ -2351,19 +2354,19 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
void spirv::FuncOp::print(OpAsmPrinter &printer) {
// Print function name, signature, and control.
printer << " ";
- printer.printSymbolName(sym_name());
+ printer.printSymbolName(getSymName());
auto fnType = getFunctionType();
function_interface_impl::printFunctionSignature(
printer, *this, fnType.getInputs(),
/*isVariadic=*/false, fnType.getResults());
- printer << " \"" << spirv::stringifyFunctionControl(function_control())
+ printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
<< "\"";
function_interface_impl::printFunctionAttributes(
printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
{spirv::attributeName<spirv::FunctionControl>()});
// Print the body if this is not an external function.
- Region &body = this->body();
+ Region &body = this->getBody();
if (!body.empty()) {
printer << ' ';
printer.printRegion(body, /*printEntryBlockArgs=*/false,
@@ -2394,7 +2397,7 @@ LogicalResult spirv::FuncOp::verifyBody() {
"returns 1 value but enclosing function requires ")
<< fnType.getNumResults() << " results";
- auto retOperandType = retOp.value().getType();
+ auto retOperandType = retOp.getValue().getType();
auto fnResultType = fnType.getResult(0);
if (retOperandType != fnResultType)
return retOp.emitOpError(" return value's type (")
@@ -2424,7 +2427,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
// CallableOpInterface
Region *spirv::FuncOp::getCallableRegion() {
- return isExternal() ? nullptr : &body();
+ return isExternal() ? nullptr : &getBody();
}
// CallableOpInterface
@@ -2437,7 +2440,7 @@ ArrayRef<Type> spirv::FuncOp::getCallableResults() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::FunctionCallOp::verify() {
- auto fnName = calleeAttr();
+ auto fnName = getCalleeAttr();
auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
@@ -2490,7 +2493,7 @@ CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
}
Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
- return arguments();
+ return getArguments();
}
//===----------------------------------------------------------------------===//
@@ -2599,11 +2602,11 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
// Print variable name.
printer << ' ';
- printer.printSymbolName(sym_name());
+ printer.printSymbolName(getSymName());
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
// Print optional initializer
- if (auto initializer = this->initializer()) {
+ if (auto initializer = this->getInitializer()) {
printer << " " << kInitializerAttrName << '(';
printer.printSymbolName(*initializer);
printer << ')';
@@ -2612,7 +2615,7 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
elidedAttrs.push_back(kTypeAttrName);
printVariableDecorations(*this, printer, elidedAttrs);
- printer << " : " << type();
+ printer << " : " << getType();
}
LogicalResult spirv::GlobalVariableOp::verify() {
@@ -2649,11 +2652,11 @@ LogicalResult spirv::GlobalVariableOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::GroupBroadcastOp::verify() {
- spirv::Scope scope = execution_scope();
+ spirv::Scope scope = getExecutionScope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
- if (auto localIdTy = localid().getType().dyn_cast<VectorType>())
+ if (auto localIdTy = getLocalid().getType().dyn_cast<VectorType>())
if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
return emitOpError("localid is a vector and can be with only "
" 2 or 3 components, actual number is ")
@@ -2667,7 +2670,7 @@ LogicalResult spirv::GroupBroadcastOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::GroupNonUniformBallotOp::verify() {
- spirv::Scope scope = execution_scope();
+ spirv::Scope scope = getExecutionScope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
@@ -2679,7 +2682,7 @@ LogicalResult spirv::GroupNonUniformBallotOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::GroupNonUniformBroadcastOp::verify() {
- spirv::Scope scope = execution_scope();
+ spirv::Scope scope = getExecutionScope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
@@ -2690,7 +2693,7 @@ LogicalResult spirv::GroupNonUniformBroadcastOp::verify() {
targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
if (targetEnv.getVersion() < spirv::Version::V_1_5) {
- auto *idOp = id().getDefiningOp();
+ auto *idOp = getId().getDefiningOp();
if (!idOp || !isa<spirv::ConstantOp, // for normal constant
spirv::ReferenceOfOp>(idOp)) // for spec constant
return emitOpError("id must be the result of a constant op");
@@ -2705,7 +2708,7 @@ LogicalResult spirv::GroupNonUniformBroadcastOp::verify() {
template <typename OpTy>
static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
- spirv::Scope scope = op.execution_scope();
+ spirv::Scope scope = op.getExecutionScope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
@@ -2756,11 +2759,11 @@ ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
}
void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
- printer << " " << ptr() << " : " << getType();
+ printer << " " << getPtr() << " : " << getType();
}
LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
- if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
+ if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
return failure();
return success();
@@ -2795,11 +2798,12 @@ ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
}
void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
- printer << " " << ptr() << ", " << value() << " : " << value().getType();
+ printer << " " << getPtr() << ", " << getValue() << " : "
+ << getValue().getType();
}
LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
- if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
+ if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
return failure();
return success();
@@ -2810,7 +2814,7 @@ LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::GroupNonUniformElectOp::verify() {
- spirv::Scope scope = execution_scope();
+ spirv::Scope scope = getExecutionScope();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
@@ -2986,7 +2990,7 @@ LogicalResult spirv::IAddCarryOp::verify() {
if (resultType.getNumElements() != 2)
return emitOpError("expected result struct type containing two members");
- if (!llvm::all_equal({operand1().getType(), operand2().getType(),
+ if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(),
resultType.getElementType(0),
resultType.getElementType(1)}))
return emitOpError(
@@ -3035,7 +3039,7 @@ LogicalResult spirv::ISubBorrowOp::verify() {
if (resultType.getNumElements() != 2)
return emitOpError("expected result struct type containing two members");
- if (!llvm::all_equal({operand1().getType(), operand2().getType(),
+ if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(),
resultType.getElementType(0),
resultType.getElementType(1)}))
return emitOpError(
@@ -3111,8 +3115,8 @@ ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) {
void spirv::LoadOp::print(OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
- ptr().getType().cast<spirv::PointerType>().getStorageClass());
- printer << " \"" << sc << "\" " << ptr();
+ getPtr().getType().cast<spirv::PointerType>().getStorageClass());
+ printer << " \"" << sc << "\" " << getPtr();
printMemoryAccessAttribute(*this, printer, elidedAttrs);
@@ -3124,7 +3128,7 @@ LogicalResult spirv::LoadOp::verify() {
// SPIR-V spec : "Result Type is the type of the loaded object. It must be a
// type with fixed size; i.e., it cannot be, nor include, any
// OpTypeRuntimeArray types."
- if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value()))) {
+ if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
return failure();
}
return verifyMemoryAccessAttribute(*this);
@@ -3148,7 +3152,7 @@ ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
}
void spirv::LoopOp::print(OpAsmPrinter &printer) {
- auto control = loop_control();
+ auto control = getLoopControl();
if (control != spirv::LoopControl::None)
printer << " control(" << spirv::stringifyLoopControl(control) << ")";
printer << ' ';
@@ -3253,33 +3257,33 @@ LogicalResult spirv::LoopOp::verifyRegions() {
}
Block *spirv::LoopOp::getEntryBlock() {
- assert(!body().empty() && "op region should not be empty!");
- return &body().front();
+ assert(!getBody().empty() && "op region should not be empty!");
+ return &getBody().front();
}
Block *spirv::LoopOp::getHeaderBlock() {
- assert(!body().empty() && "op region should not be empty!");
+ assert(!getBody().empty() && "op region should not be empty!");
// The second block is the loop header block.
- return &*std::next(body().begin());
+ return &*std::next(getBody().begin());
}
Block *spirv::LoopOp::getContinueBlock() {
- assert(!body().empty() && "op region should not be empty!");
+ assert(!getBody().empty() && "op region should not be empty!");
// The second to last block is the loop continue block.
- return &*std::prev(body().end(), 2);
+ return &*std::prev(getBody().end(), 2);
}
Block *spirv::LoopOp::getMergeBlock() {
- assert(!body().empty() && "op region should not be empty!");
+ assert(!getBody().empty() && "op region should not be empty!");
// The last block is the loop merge block.
- return &body().back();
+ return &getBody().back();
}
void spirv::LoopOp::addEntryAndMergeBlock() {
- assert(body().empty() && "entry and merge block already exist");
- body().push_back(new Block());
+ assert(getBody().empty() && "entry and merge block already exist");
+ getBody().push_back(new Block());
auto *mergeBlock = new Block();
- body().push_back(mergeBlock);
+ getBody().push_back(mergeBlock);
OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
// Add a spv.mlir.merge op into the merge block.
@@ -3291,7 +3295,7 @@ void spirv::LoopOp::addEntryAndMergeBlock() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::MemoryBarrierOp::verify() {
- return verifyMemorySemantics(getOperation(), memory_semantics());
+ return verifyMemorySemantics(getOperation(), getMemorySemantics());
}
//===----------------------------------------------------------------------===//
@@ -3390,14 +3394,14 @@ void spirv::ModuleOp::print(OpAsmPrinter &printer) {
SmallVector<StringRef, 2> elidedAttrs;
- printer << " " << spirv::stringifyAddressingModel(addressing_model()) << " "
- << spirv::stringifyMemoryModel(memory_model());
+ printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
+ << spirv::stringifyMemoryModel(getMemoryModel());
auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
mlir::SymbolTable::getSymbolAttrName()});
- if (Optional<spirv::VerCapExtAttr> triple = vce_triple()) {
+ if (Optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
printer << " requires " << *triple;
elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
}
@@ -3421,12 +3425,12 @@ LogicalResult spirv::ModuleOp::verifyRegions() {
// duplicated in EntryPointOps. Also verify that the interface specified
// comes from globalVariables here to make this check cheaper.
if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
- auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
+ auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
if (!funcOp) {
return entryPointOp.emitError("function '")
- << entryPointOp.fn() << "' not found in 'spv.module'";
+ << entryPointOp.getFn() << "' not found in 'spv.module'";
}
- if (auto interface = entryPointOp.interface()) {
+ if (auto interface = entryPointOp.getInterface()) {
for (Attribute varRef : interface) {
auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
if (!varSymRef) {
@@ -3446,7 +3450,7 @@ LogicalResult spirv::ModuleOp::verifyRegions() {
}
auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
- funcOp, entryPointOp.execution_model());
+ funcOp, entryPointOp.getExecutionModel());
auto entryPtIt = entryPoints.find(key);
if (entryPtIt != entryPoints.end()) {
return entryPointOp.emitError("duplicate of a previous EntryPointOp");
@@ -3475,23 +3479,23 @@ LogicalResult spirv::ModuleOp::verifyRegions() {
LogicalResult spirv::ReferenceOfOp::verify() {
auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
- (*this)->getParentOp(), spec_constAttr());
+ (*this)->getParentOp(), getSpecConstAttr());
Type constType;
auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
if (specConstOp)
- constType = specConstOp.default_value().getType();
+ constType = specConstOp.getDefaultValue().getType();
auto specConstCompositeOp =
dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
if (specConstCompositeOp)
- constType = specConstCompositeOp.type();
+ constType = specConstCompositeOp.getType();
if (!specConstOp && !specConstCompositeOp)
return emitOpError(
"expected spv.SpecConstant or spv.SpecConstantComposite symbol");
- if (reference().getType() != constType)
+ if (getReference().getType() != constType)
return emitOpError("result type mismatch with the referenced "
"specialization constant's type");
@@ -3521,8 +3525,8 @@ LogicalResult spirv::ReturnValueOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::SelectOp::verify() {
- if (auto conditionTy = condition().getType().dyn_cast<VectorType>()) {
- auto resultVectorTy = result().getType().dyn_cast<VectorType>();
+ if (auto conditionTy = getCondition().getType().dyn_cast<VectorType>()) {
+ auto resultVectorTy = getResult().getType().dyn_cast<VectorType>();
if (!resultVectorTy) {
return emitOpError("result expected to be of vector type when "
"condition is of vector type");
@@ -3548,7 +3552,7 @@ ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
}
void spirv::SelectionOp::print(OpAsmPrinter &printer) {
- auto control = selection_control();
+ auto control = getSelectionControl();
if (control != spirv::SelectionControl::None)
printer << " control(" << spirv::stringifySelectionControl(control) << ")";
printer << ' ';
@@ -3598,21 +3602,21 @@ LogicalResult spirv::SelectionOp::verifyRegions() {
}
Block *spirv::SelectionOp::getHeaderBlock() {
- assert(!body().empty() && "op region should not be empty!");
+ assert(!getBody().empty() && "op region should not be empty!");
// The first block is the loop header block.
- return &body().front();
+ return &getBody().front();
}
Block *spirv::SelectionOp::getMergeBlock() {
- assert(!body().empty() && "op region should not be empty!");
+ assert(!getBody().empty() && "op region should not be empty!");
// The last block is the loop merge block.
- return &body().back();
+ return &getBody().back();
}
void spirv::SelectionOp::addMergeBlock() {
- assert(body().empty() && "entry and merge block already exist");
+ assert(getBody().empty() && "entry and merge block already exist");
auto *mergeBlock = new Block();
- body().push_back(mergeBlock);
+ getBody().push_back(mergeBlock);
OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
// Add a spv.mlir.merge op into the merge block.
@@ -3682,10 +3686,10 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
printer << ' ';
- printer.printSymbolName(sym_name());
+ printer.printSymbolName(getSymName());
if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
- printer << " = " << default_value();
+ printer << " = " << getDefaultValue();
}
LogicalResult spirv::SpecConstantOp::verify() {
@@ -3693,7 +3697,7 @@ LogicalResult spirv::SpecConstantOp::verify() {
if (specID.getValue().isNegative())
return emitOpError("SpecId cannot be negative");
- auto value = default_value();
+ auto value = getDefaultValue();
if (value.isa<IntegerAttr, FloatAttr>()) {
// Make sure bitwidth is allowed.
if (!value.getType().isa<spirv::SPIRVType>())
@@ -3732,19 +3736,19 @@ ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) {
void spirv::StoreOp::print(OpAsmPrinter &printer) {
SmallVector<StringRef, 4> elidedAttrs;
StringRef sc = stringifyStorageClass(
- ptr().getType().cast<spirv::PointerType>().getStorageClass());
- printer << " \"" << sc << "\" " << ptr() << ", " << value();
+ getPtr().getType().cast<spirv::PointerType>().getStorageClass());
+ printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
printMemoryAccessAttribute(*this, printer, elidedAttrs);
- printer << " : " << value().getType();
+ printer << " : " << getValue().getType();
printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
}
LogicalResult spirv::StoreOp::verify() {
// SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
// OpTypePointer whose Type operand is the same as the type of Object."
- if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value())))
+ if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
return failure();
return verifyMemoryAccessAttribute(*this);
}
@@ -3819,7 +3823,7 @@ void spirv::VariableOp::print(OpAsmPrinter &printer) {
spirv::attributeName<spirv::StorageClass>()};
// Print optional initializer
if (getNumOperands() != 0)
- printer << " init(" << initializer() << ")";
+ printer << " init(" << getInitializer() << ")";
printVariableDecorations(*this, printer, elidedAttrs);
printer << " : " << getType();
@@ -3829,14 +3833,14 @@ LogicalResult spirv::VariableOp::verify() {
// SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
// object. It cannot be Generic. It must be the same as the Storage Class
// operand of the Result Type."
- if (storage_class() != spirv::StorageClass::Function) {
+ if (getStorageClass() != spirv::StorageClass::Function) {
return emitOpError(
"can only be used to model function-level variables. Use "
"spv.GlobalVariable for module-level variables.");
}
- auto pointerType = pointer().getType().cast<spirv::PointerType>();
- if (storage_class() != pointerType.getStorageClass())
+ auto pointerType = getPointer().getType().cast<spirv::PointerType>();
+ if (getStorageClass() != pointerType.getStorageClass())
return emitOpError(
"storage class must match result pointer's storage class");
@@ -3877,17 +3881,17 @@ LogicalResult spirv::VectorShuffleOp::verify() {
VectorType resultType = getType().cast<VectorType>();
size_t numResultElements = resultType.getNumElements();
- if (numResultElements != components().size())
+ if (numResultElements != getComponents().size())
return emitOpError("result type element count (")
<< numResultElements
<< ") mismatch with the number of component selectors ("
- << components().size() << ")";
+ << getComponents().size() << ")";
size_t totalSrcElements =
- vector1().getType().cast<VectorType>().getNumElements() +
- vector2().getType().cast<VectorType>().getNumElements();
+ getVector1().getType().cast<VectorType>().getNumElements() +
+ getVector2().getType().cast<VectorType>().getNumElements();
- for (const auto &selector : components().getAsValueRange<IntegerAttr>()) {
+ for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
uint32_t index = selector.getZExtValue();
if (index >= totalSrcElements &&
index != std::numeric_limits<uint32_t>().max())
@@ -3925,11 +3929,12 @@ ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
}
void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
- printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
+ printer << " " << getPointer() << ", " << getStride() << ", "
+ << getColumnmajor();
// Print optional memory access attribute.
- if (auto memAccess = memory_access())
+ if (auto memAccess = getMemoryAccess())
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
- printer << " : " << pointer().getType() << " as " << getType();
+ printer << " : " << getPointer().getType() << " as " << getType();
}
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
@@ -3952,8 +3957,8 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
}
LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
- return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
- result().getType());
+ return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
+ getResult().getType());
}
//===----------------------------------------------------------------------===//
@@ -3983,17 +3988,17 @@ ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
}
void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
- printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
- << columnmajor();
+ printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
+ << ", " << getColumnmajor();
// Print optional memory access attribute.
- if (auto memAccess = memory_access())
+ if (auto memAccess = getMemoryAccess())
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
- printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
+ printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
}
LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
- return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
- object().getType());
+ return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
+ getObject().getType());
}
//===----------------------------------------------------------------------===//
@@ -4002,12 +4007,12 @@ LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
static LogicalResult
verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
- if (op.c().getType() != op.result().getType())
+ if (op.getC().getType() != op.getResult().getType())
return op.emitOpError("result and third operand must have the same type");
- auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
- auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
- auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
- auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
+ auto typeA = op.getA().getType().cast<spirv::CooperativeMatrixNVType>();
+ auto typeB = op.getB().getType().cast<spirv::CooperativeMatrixNVType>();
+ auto typeC = op.getC().getType().cast<spirv::CooperativeMatrixNVType>();
+ auto typeR = op.getResult().getType().cast<spirv::CooperativeMatrixNVType>();
if (typeA.getRows() != typeR.getRows() ||
typeA.getColumns() != typeB.getRows() ||
typeB.getColumns() != typeR.getColumns())
@@ -4050,8 +4055,8 @@ verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
//===----------------------------------------------------------------------===//
LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
- return verifyPointerAndJointMatrixType(*this, pointer().getType(),
- result().getType());
+ return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
+ getResult().getType());
}
//===----------------------------------------------------------------------===//
@@ -4059,8 +4064,8 @@ LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
- return verifyPointerAndJointMatrixType(*this, pointer().getType(),
- object().getType());
+ return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
+ getObject().getType());
}
//===----------------------------------------------------------------------===//
@@ -4068,12 +4073,12 @@ LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
//===----------------------------------------------------------------------===//
static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
- if (op.c().getType() != op.result().getType())
+ if (op.getC().getType() != op.getResult().getType())
return op.emitOpError("result and third operand must have the same type");
- auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
- auto typeB = op.b().getType().cast<spirv::JointMatrixINTELType>();
- auto typeC = op.c().getType().cast<spirv::JointMatrixINTELType>();
- auto typeR = op.result().getType().cast<spirv::JointMatrixINTELType>();
+ auto typeA = op.getA().getType().cast<spirv::JointMatrixINTELType>();
+ auto typeB = op.getB().getType().cast<spirv::JointMatrixINTELType>();
+ auto typeC = op.getC().getType().cast<spirv::JointMatrixINTELType>();
+ auto typeR = op.getResult().getType().cast<spirv::JointMatrixINTELType>();
if (typeA.getRows() != typeR.getRows() ||
typeA.getColumns() != typeB.getRows() ||
typeB.getColumns() != typeR.getColumns())
@@ -4100,11 +4105,11 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {
// We already checked that result and matrix are both of matrix type in the
// auto-generated verify method.
- auto inputMatrix = matrix().getType().cast<spirv::MatrixType>();
- auto resultMatrix = result().getType().cast<spirv::MatrixType>();
+ auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
+ auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
// Check that the scalar type is the same as the matrix element type.
- if (scalar().getType() != inputMatrix.getElementType())
+ if (getScalar().getType() != inputMatrix.getElementType())
return emitError("input matrix components' type and scaling value must "
"have the same type");
@@ -4137,22 +4142,23 @@ void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
printer << ' ';
StringRef targetStorageClass = stringifyStorageClass(
- target().getType().cast<spirv::PointerType>().getStorageClass());
- printer << " \"" << targetStorageClass << "\" " << target() << ", ";
+ getTarget().getType().cast<spirv::PointerType>().getStorageClass());
+ printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
StringRef sourceStorageClass = stringifyStorageClass(
- source().getType().cast<spirv::PointerType>().getStorageClass());
- printer << " \"" << sourceStorageClass << "\" " << source();
+ getSource().getType().cast<spirv::PointerType>().getStorageClass());
+ printer << " \"" << sourceStorageClass << "\" " << getSource();
SmallVector<StringRef, 4> elidedAttrs;
printMemoryAccessAttribute(*this, printer, elidedAttrs);
printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
- source_memory_access(), source_alignment());
+ getSourceMemoryAccess(),
+ getSourceAlignment());
printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
Type pointeeType =
- target().getType().cast<spirv::PointerType>().getPointeeType();
+ getTarget().getType().cast<spirv::PointerType>().getPointeeType();
printer << " : " << pointeeType;
}
@@ -4200,10 +4206,10 @@ ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
LogicalResult spirv::CopyMemoryOp::verify() {
Type targetType =
- target().getType().cast<spirv::PointerType>().getPointeeType();
+ getTarget().getType().cast<spirv::PointerType>().getPointeeType();
Type sourceType =
- source().getType().cast<spirv::PointerType>().getPointeeType();
+ getSource().getType().cast<spirv::PointerType>().getPointeeType();
if (targetType != sourceType)
return emitOpError("both operands must be pointers to the same type");
@@ -4227,8 +4233,8 @@ LogicalResult spirv::CopyMemoryOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::TransposeOp::verify() {
- auto inputMatrix = matrix().getType().cast<spirv::MatrixType>();
- auto resultMatrix = result().getType().cast<spirv::MatrixType>();
+ auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
+ auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
// Verify that the input and output matrices have correct shapes.
if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
@@ -4252,9 +4258,9 @@ LogicalResult spirv::TransposeOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::MatrixTimesMatrixOp::verify() {
- auto leftMatrix = leftmatrix().getType().cast<spirv::MatrixType>();
- auto rightMatrix = rightmatrix().getType().cast<spirv::MatrixType>();
- auto resultMatrix = result().getType().cast<spirv::MatrixType>();
+ auto leftMatrix = getLeftmatrix().getType().cast<spirv::MatrixType>();
+ auto rightMatrix = getRightmatrix().getType().cast<spirv::MatrixType>();
+ auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
// left matrix columns' count and right matrix rows' count must be equal
if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
@@ -4329,23 +4335,23 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
printer << " ";
- printer.printSymbolName(sym_name());
+ printer.printSymbolName(getSymName());
printer << " (";
- auto constituents = this->constituents().getValue();
+ auto constituents = this->getConstituents().getValue();
if (!constituents.empty())
llvm::interleaveComma(constituents, printer);
- printer << ") : " << type();
+ printer << ") : " << getType();
}
LogicalResult spirv::SpecConstantCompositeOp::verify() {
- auto cType = type().dyn_cast<spirv::CompositeType>();
- auto constituents = this->constituents().getValue();
+ auto cType = getType().dyn_cast<spirv::CompositeType>();
+ auto constituents = this->getConstituents().getValue();
if (!cType)
return emitError("result type must be a composite type, but provided ")
- << type();
+ << getType();
if (cType.isa<spirv::CooperativeMatrixNVType>())
return emitError("unsupported composite type ") << cType;
@@ -4363,11 +4369,11 @@ LogicalResult spirv::SpecConstantCompositeOp::verify() {
dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), constituent.getAttr()));
- if (constituentSpecConstOp.default_value().getType() !=
+ if (constituentSpecConstOp.getDefaultValue().getType() !=
cType.getElementType(index))
return emitError("has incorrect types of operands: expected ")
<< cType.getElementType(index) << ", but provided "
- << constituentSpecConstOp.default_value().getType();
+ << constituentSpecConstOp.getDefaultValue().getType();
}
return success();
@@ -4406,7 +4412,7 @@ ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
printer << " wraps ";
- printer.printGenericOp(&body().front().front());
+ printer.printGenericOp(&getBody().front().front());
}
LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
@@ -4434,7 +4440,8 @@ LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::GLFrexpStructOp::verify() {
- spirv::StructType structTy = result().getType().dyn_cast<spirv::StructType>();
+ spirv::StructType structTy =
+ getResult().getType().dyn_cast<spirv::StructType>();
if (structTy.getNumElements() != 2)
return emitError("result type must be a struct type with two memebers");
@@ -4444,7 +4451,7 @@ LogicalResult spirv::GLFrexpStructOp::verify() {
VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
- Type operandTy = operand().getType();
+ Type operandTy = getOperand().getType();
VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
FloatType operandFTy = operandTy.dyn_cast<FloatType>();
@@ -4480,8 +4487,8 @@ LogicalResult spirv::GLFrexpStructOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::GLLdexpOp::verify() {
- Type significandType = x().getType();
- Type exponentType = exp().getType();
+ Type significandType = getX().getType();
+ Type exponentType = getExp().getType();
if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
return emitOpError("operands must both be scalars or vectors");
@@ -4503,9 +4510,9 @@ LogicalResult spirv::GLLdexpOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::ImageDrefGatherOp::verify() {
- VectorType resultType = result().getType().cast<VectorType>();
+ VectorType resultType = getResult().getType().cast<VectorType>();
auto sampledImageType =
- sampledimage().getType().cast<spirv::SampledImageType>();
+ getSampledimage().getType().cast<spirv::SampledImageType>();
auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
if (resultType.getNumElements() != 4)
@@ -4530,8 +4537,8 @@ LogicalResult spirv::ImageDrefGatherOp::verify() {
if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
return emitOpError("the MS operand of the underlying image type must be 0");
- spirv::ImageOperandsAttr attr = imageoperandsAttr();
- auto operandArguments = operand_arguments();
+ spirv::ImageOperandsAttr attr = getImageoperandsAttr();
+ auto operandArguments = getOperandArguments();
return verifyImageOperands(*this, attr, operandArguments);
}
@@ -4565,8 +4572,8 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::ImageQuerySizeOp::verify() {
- spirv::ImageType imageType = image().getType().cast<spirv::ImageType>();
- Type resultType = result().getType();
+ spirv::ImageType imageType = getImage().getType().cast<spirv::ImageType>();
+ Type resultType = getResult().getType();
spirv::Dim dim = imageType.getDim();
spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
@@ -4668,9 +4675,9 @@ static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
template <typename Op>
static auto concatElemAndIndices(Op op) {
- SmallVector<Value> ret(op.indices().size() + 1);
- ret[0] = op.element();
- llvm::copy(op.indices(), ret.begin() + 1);
+ SmallVector<Value> ret(op.getIndices().size() + 1);
+ ret[0] = op.getElement();
+ llvm::copy(op.getIndices(), ret.begin() + 1);
return ret;
}
@@ -4698,7 +4705,7 @@ void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::InBoundsPtrAccessChainOp::verify() {
- return verifyAccessChain(*this, indices());
+ return verifyAccessChain(*this, getIndices());
}
//===----------------------------------------------------------------------===//
@@ -4724,7 +4731,7 @@ void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) {
}
LogicalResult spirv::PtrAccessChainOp::verify() {
- return verifyAccessChain(*this, indices());
+ return verifyAccessChain(*this, getIndices());
}
//===----------------------------------------------------------------------===//
@@ -4732,10 +4739,10 @@ LogicalResult spirv::PtrAccessChainOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult spirv::VectorTimesScalarOp::verify() {
- if (vector().getType() != getType())
+ if (getVector().getType() != getType())
return emitOpError("vector operand and result type mismatch");
auto scalarType = getType().cast<VectorType>().getElementType();
- if (scalar().getType() != scalarType)
+ if (getScalar().getType() != scalarType)
return emitOpError("scalar operand and result element type match");
return success();
}
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
index 80e8f3c185f26..888e1ca8456c6 100644
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -94,16 +94,16 @@ OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
return nullptr;
spirv::ModuleOp firstModule = inputModules.front();
- auto addressingModel = firstModule.addressing_model();
- auto memoryModel = firstModule.memory_model();
- auto vceTriple = firstModule.vce_triple();
+ auto addressingModel = firstModule.getAddressingModel();
+ auto memoryModel = firstModule.getMemoryModel();
+ auto vceTriple = firstModule.getVceTriple();
// First check whether there are conflicts between addressing/memory model.
// Return early if so.
for (auto module : inputModules) {
- if (module.addressing_model() != addressingModel ||
- module.memory_model() != memoryModel ||
- module.vce_triple() != vceTriple) {
+ if (module.getAddressingModel() != addressingModel ||
+ module.getMemoryModel() != memoryModel ||
+ module.getVceTriple() != vceTriple) {
module.emitError("input modules differ in addressing model, memory "
"model, and/or VCE triple");
return nullptr;
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index 2e8f70c7e434e..06649bad2440a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -40,7 +40,7 @@ class SPIRVGlobalVariableOpLayoutInfoDecoration
PatternRewriter &rewriter) const override {
SmallVector<NamedAttribute, 4> globalVarAttrs;
- auto ptrType = op.type().cast<spirv::PointerType>();
+ auto ptrType = op.getType().cast<spirv::PointerType>();
auto structType = VulkanLayoutUtils::decorateType(
ptrType.getPointeeType().cast<spirv::StructType>());
@@ -71,11 +71,11 @@ class SPIRVAddressOfOpLayoutInfoDecoration
LogicalResult matchAndRewrite(spirv::AddressOfOp op,
PatternRewriter &rewriter) const override {
auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
- auto varName = op.variableAttr();
+ auto varName = op.getVariableAttr();
auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
- op, varOp.type(), SymbolRefAttr::get(varName.getAttr()));
+ op, varOp.getType(), SymbolRefAttr::get(varName.getAttr()));
return success();
}
};
@@ -121,12 +121,12 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
target.addLegalOp<func::FuncOp>();
target.addDynamicallyLegalOp<spirv::GlobalVariableOp>(
[](spirv::GlobalVariableOp op) {
- return VulkanLayoutUtils::isLegalType(op.type());
+ return VulkanLayoutUtils::isLegalType(op.getType());
});
// Change the type for the direct users.
target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
- return VulkanLayoutUtils::isLegalType(op.pointer().getType());
+ return VulkanLayoutUtils::isLegalType(op.getPointer().getType());
});
// Change the type for the indirect users.
@@ -134,7 +134,8 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
spirv::StoreOp>([&](Operation *op) {
for (Value operand : op->getOperands()) {
auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>();
- if (addrOp && !VulkanLayoutUtils::isLegalType(addrOp.pointer().getType()))
+ if (addrOp &&
+ !VulkanLayoutUtils::isLegalType(addrOp.getPointer().getType()))
return false;
}
return true;
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index ad296e32fd29c..8ed2d0095a54f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -88,13 +88,13 @@ getInterfaceVariables(spirv::FuncOp funcOp,
// instructions in this function.
funcOp.walk([&](spirv::AddressOfOp addressOfOp) {
auto var =
- module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
+ module.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.getVariable());
// TODO: Per SPIR-V spec: "Before version 1.4, the interface’s
// storage classes are limited to the Input and Output storage classes.
// Starting with version 1.4, the interface’s storage classes are all
// storage classes used in declaring all global variables referenced by the
// entry point’s call tree." We should consider the target environment here.
- switch (var.type().cast<spirv::PointerType>().getStorageClass()) {
+ switch (var.getType().cast<spirv::PointerType>().getStorageClass()) {
case spirv::StorageClass::Input:
case spirv::StorageClass::Output:
interfaceVarSet.insert(var.getOperation());
@@ -105,7 +105,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
});
for (auto &var : interfaceVarSet) {
interfaceVars.push_back(SymbolRefAttr::get(
- funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name()));
+ funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).getSymName()));
}
return success();
}
@@ -223,7 +223,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
auto zero =
spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), rewriter);
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
- funcOp.getLoc(), replacement, zero.constant());
+ funcOp.getLoc(), replacement, zero.getConstant());
replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr);
}
signatureConverter.remapInput(argType.index(), replacement);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index 95d21f35e0d08..26f9579c5f061 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -63,7 +63,7 @@ void RewriteInsertsPass::runOnOperation() {
SmallVector<Value, 4> operands;
// Collect inserted objects.
for (auto insertionOp : insertions)
- operands.push_back(insertionOp.object());
+ operands.push_back(insertionOp.getObject());
OpBuilder builder(lastCompositeInsertOp);
auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
@@ -84,11 +84,13 @@ void RewriteInsertsPass::runOnOperation() {
LogicalResult RewriteInsertsPass::collectInsertionChain(
spirv::CompositeInsertOp op,
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
- auto indicesArrayAttr = op.indices().cast<ArrayAttr>();
+ auto indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
// TODO: handle nested composite object.
if (indicesArrayAttr.size() == 1) {
- auto numElements =
- op.composite().getType().cast<spirv::CompositeType>().getNumElements();
+ auto numElements = op.getComposite()
+ .getType()
+ .cast<spirv::CompositeType>()
+ .getNumElements();
auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
// Need a last index to collect a sequential chain.
@@ -102,12 +104,12 @@ LogicalResult RewriteInsertsPass::collectInsertionChain(
if (index == 0)
return success();
- op = op.composite().getDefiningOp<spirv::CompositeInsertOp>();
+ op = op.getComposite().getDefiningOp<spirv::CompositeInsertOp>();
if (!op)
return failure();
--index;
- indicesArrayAttr = op.indices().cast<ArrayAttr>();
+ indicesArrayAttr = op.getIndices().cast<ArrayAttr>();
if ((indicesArrayAttr.size() != 1) ||
(indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index))
return failure();
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index b56b8c02943d6..c52d484ea57a8 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -642,7 +642,7 @@ static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
unsigned elementCount) {
for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
- auto ptrType = varOp.type().dyn_cast<spirv::PointerType>();
+ auto ptrType = varOp.getType().dyn_cast<spirv::PointerType>();
if (!ptrType)
continue;
@@ -874,7 +874,7 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
// Special treatment for global variables, whose type requirements are
// conveyed by type attributes.
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
- valueTypes.push_back(globalVar.type());
+ valueTypes.push_back(globalVar.getType());
// Make sure the op's operands/results use types that are allowed by the
// target environment.
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 140e316196491..397c330b954ff 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -51,8 +51,8 @@ static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
AliasedResourceMap aliasedResources;
moduleOp->walk([&aliasedResources](spirv::GlobalVariableOp varOp) {
if (varOp->getAttrOfType<UnitAttr>("aliased")) {
- Optional<uint32_t> set = varOp.descriptor_set();
- Optional<uint32_t> binding = varOp.binding();
+ Optional<uint32_t> set = varOp.getDescriptorSet();
+ Optional<uint32_t> binding = varOp.getBinding();
if (set && binding)
aliasedResources[{*set, *binding}].push_back(varOp);
}
@@ -222,16 +222,16 @@ bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
}
if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
- auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
+ auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable());
return shouldUnify(varOp);
}
if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
- return shouldUnify(acOp.base_ptr().getDefiningOp());
+ return shouldUnify(acOp.getBasePtr().getDefiningOp());
if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
- return shouldUnify(loadOp.ptr().getDefiningOp());
+ return shouldUnify(loadOp.getPtr().getDefiningOp());
if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
- return shouldUnify(storeOp.ptr().getDefiningOp());
+ return shouldUnify(storeOp.getPtr().getDefiningOp());
return false;
}
@@ -265,7 +265,7 @@ void ResourceAliasAnalysis::recordIfUnifiable(
// Collect the element types for all resources in the current set.
SmallVector<spirv::SPIRVType> elementTypes;
for (spirv::GlobalVariableOp resource : resources) {
- Type elementType = getRuntimeArrayElementType(resource.type());
+ Type elementType = getRuntimeArrayElementType(resource.getType());
if (!elementType)
return; // Unexpected resource variable type.
@@ -326,7 +326,7 @@ struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
// Rewrite the AddressOf op to get the address of the canoncical resource.
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
auto srcVarOp = cast<spirv::GlobalVariableOp>(
- SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
+ SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
return success();
@@ -339,13 +339,13 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
LogicalResult
matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
+ auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
if (!addressOp)
return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
auto srcVarOp = cast<spirv::GlobalVariableOp>(
- SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
+ SymbolTable::lookupSymbolIn(moduleOp, addressOp.getVariable()));
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
@@ -356,7 +356,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
// We have the same bitwidth for source and destination element types.
// Thie indices keep the same.
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
- acOp, adaptor.base_ptr(), adaptor.indices());
+ acOp, adaptor.getBasePtr(), adaptor.getIndices());
return success();
}
@@ -375,7 +375,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
- auto indices = llvm::to_vector<4>(acOp.indices());
+ auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back();
indices.back() =
rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
@@ -383,7 +383,7 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
- acOp, adaptor.base_ptr(), indices);
+ acOp, adaptor.getBasePtr(), indices);
return success();
}
@@ -399,13 +399,13 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
- auto indices = llvm::to_vector<4>(acOp.indices());
+ auto indices = llvm::to_vector<4>(acOp.getIndices());
Value oldIndex = indices.back();
indices.back() =
rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
- acOp, adaptor.base_ptr(), indices);
+ acOp, adaptor.getBasePtr(), indices);
return success();
}
@@ -420,13 +420,13 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
LogicalResult
matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto srcPtrType = loadOp.ptr().getType().cast<spirv::PointerType>();
+ auto srcPtrType = loadOp.getPtr().getType().cast<spirv::PointerType>();
auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
- auto dstPtrType = adaptor.ptr().getType().cast<spirv::PointerType>();
+ auto dstPtrType = adaptor.getPtr().getType().cast<spirv::PointerType>();
auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();
Location loc = loadOp.getLoc();
- auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
+ auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
if (srcElemType == dstElemType) {
rewriter.replaceOp(loadOp, newLoadOp->getResults());
return success();
@@ -434,7 +434,7 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
- newLoadOp.value());
+ newLoadOp.getValue());
rewriter.replaceOp(loadOp, castOp->getResults());
return success();
@@ -457,19 +457,19 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
components.reserve(ratio);
components.push_back(newLoadOp);
- auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>();
+ auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
if (!acOp)
return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");
auto i32Type = rewriter.getI32Type();
Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
- auto indices = llvm::to_vector<4>(acOp.indices());
+ auto indices = llvm::to_vector<4>(acOp.getIndices());
for (int i = 1; i < ratio; ++i) {
// Load all subsequent components belonging to this element.
indices.back() = rewriter.create<spirv::IAddOp>(
loc, i32Type, indices.back(), oneValue);
auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
- loc, acOp.base_ptr(), indices);
+ loc, acOp.getBasePtr(), indices);
// Assuming little endian, this reads lower-ordered bits of the number
// to lower-numbered components of the vector.
components.push_back(
@@ -504,19 +504,19 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcElemType =
- storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+ storeOp.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
auto dstElemType =
- adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+ adaptor.getPtr().getType().cast<spirv::PointerType>().getPointeeType();
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
return rewriter.notifyMatchFailure(storeOp, "not scalar type");
if (!areSameBitwidthScalarType(srcElemType, dstElemType))
return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
Location loc = storeOp.getLoc();
- Value value = adaptor.value();
+ Value value = adaptor.getValue();
if (srcElemType != dstElemType)
value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
- rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
+ rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.getPtr(), value,
storeOp->getAttrs());
return success();
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 9cb79f80bab06..d7cf3f5fcff5d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -151,7 +151,7 @@ void UpdateVCEPass::runOnOperation() {
// Special treatment for global variables, whose type requirements are
// conveyed by type attributes.
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
- valueTypes.push_back(globalVar.type());
+ valueTypes.push_back(globalVar.getType());
// Requirements from values' types
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 5fc3000a5bc80..2b81e91ec7691 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -46,20 +46,20 @@ Value spirv::Deserializer::getValue(uint32_t id) {
}
if (auto varOp = getGlobalVariable(id)) {
auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
- unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation()));
- return addressOfOp.pointer();
+ unknownLoc, varOp.getType(), SymbolRefAttr::get(varOp.getOperation()));
+ return addressOfOp.getPointer();
}
if (auto constOp = getSpecConstant(id)) {
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
- unknownLoc, constOp.default_value().getType(),
+ unknownLoc, constOp.getDefaultValue().getType(),
SymbolRefAttr::get(constOp.getOperation()));
- return referenceOfOp.reference();
+ return referenceOfOp.getReference();
}
if (auto constCompositeOp = getSpecConstantComposite(id)) {
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
- unknownLoc, constCompositeOp.type(),
+ unknownLoc, constCompositeOp.getType(),
SymbolRefAttr::get(constCompositeOp.getOperation()));
- return referenceOfOp.reference();
+ return referenceOfOp.getReference();
}
if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
return materializeSpecConstantOperation(
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 84d8d9caf202f..604c71c9808d9 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1414,7 +1414,7 @@ Value spirv::Deserializer::materializeSpecConstantOperation(
auto specConstOperationOp =
opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
- Region &body = specConstOperationOp.body();
+ Region &body = specConstOperationOp.getBody();
// Move the new block into SpecConstantOperation's body.
body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
Region::iterator(enclosedBlock));
@@ -1983,17 +1983,17 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
assert((branchCondOp.getTrueBlock() == target ||
branchCondOp.getFalseBlock() == target) &&
"expected target to be either the true or false target");
- if (target == branchCondOp.trueTarget())
+ if (target == branchCondOp.getTrueTarget())
opBuilder.create<spirv::BranchConditionalOp>(
- branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
+ branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
branchCondOp.getFalseBlockArguments(),
- branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
- branchCondOp.falseTarget());
+ branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
+ branchCondOp.getFalseTarget());
else
opBuilder.create<spirv::BranchConditionalOp>(
- branchCondOp.getLoc(), branchCondOp.condition(),
+ branchCondOp.getLoc(), branchCondOp.getCondition(),
branchCondOp.getTrueBlockArguments(), blockArgs,
- branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
+ branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
branchCondOp.getFalseBlock());
branchCondOp.erase();
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
index 7fa313f35f770..80d1bb85611db 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
@@ -24,7 +24,7 @@ namespace mlir {
LogicalResult spirv::serialize(spirv::ModuleOp module,
SmallVectorImpl<uint32_t> &binary,
const SerializationOptions &options) {
- if (!module.vce_triple())
+ if (!module.getVceTriple())
return module.emitError(
"module must have 'vce_triple' attribute to be serializeable");
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ada1483757742..6197664e1e9bf 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -58,7 +58,8 @@ visitInPrettyBlockOrder(Block *headerBlock,
namespace mlir {
namespace spirv {
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
- if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
+ if (auto resultID =
+ prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
valueIDMap[op.getResult()] = resultID;
return success();
}
@@ -66,7 +67,7 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
}
LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
- if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
+ if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
/*isSpec=*/true)) {
// Emit the OpDecorate instruction for SpecId.
if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
@@ -75,8 +76,8 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
return failure();
}
- specConstIDMap[op.sym_name()] = resultID;
- return processName(resultID, op.sym_name());
+ specConstIDMap[op.getSymName()] = resultID;
+ return processName(resultID, op.getSymName());
}
return failure();
}
@@ -84,7 +85,7 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
LogicalResult
Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
uint32_t typeID = 0;
- if (failed(processType(op.getLoc(), op.type(), typeID))) {
+ if (failed(processType(op.getLoc(), op.getType(), typeID))) {
return failure();
}
@@ -94,7 +95,7 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
operands.push_back(typeID);
operands.push_back(resultID);
- auto constituents = op.constituents();
+ auto constituents = op.getConstituents();
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
@@ -112,9 +113,9 @@ Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
encodeInstructionInto(typesGlobalValues,
spirv::Opcode::OpSpecConstantComposite, operands);
- specConstIDMap[op.sym_name()] = resultID;
+ specConstIDMap[op.getSymName()] = resultID;
- return processName(resultID, op.sym_name());
+ return processName(resultID, op.getSymName());
}
LogicalResult
@@ -199,7 +200,7 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
operands.push_back(resTypeID);
auto funcID = getOrCreateFunctionID(op.getName());
operands.push_back(funcID);
- operands.push_back(static_cast<uint32_t>(op.function_control()));
+ operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
operands.push_back(fnTypeID);
encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
@@ -310,7 +311,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
// Get TypeID.
uint32_t resultTypeID = 0;
SmallVector<StringRef, 4> elidedAttrs;
- if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
+ if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
return failure();
}
@@ -320,7 +321,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
auto resultID = getNextID();
// Encode the name.
- auto varName = varOp.sym_name();
+ auto varName = varOp.getSymName();
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
if (failed(processName(resultID, varName))) {
return failure();
@@ -332,7 +333,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
// Encode initialization.
- if (auto initializer = varOp.initializer()) {
+ if (auto initializer = varOp.getInitializer()) {
auto initializerID = getVariableID(*initializer);
if (!initializerID) {
return emitError(varOp.getLoc(),
@@ -364,7 +365,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
// Assign <id>s to all blocks so that branches inside the SelectionOp can
// resolve properly.
- auto &body = selectionOp.body();
+ auto &body = selectionOp.getBody();
for (Block &block : body)
getOrCreateBlockID(&block);
@@ -390,7 +391,7 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
lastProcessedWasMergeInst = true;
encodeInstructionInto(
functionBody, spirv::Opcode::OpSelectionMerge,
- {mergeID, static_cast<uint32_t>(selectionOp.selection_control())});
+ {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
return success();
};
if (failed(
@@ -420,7 +421,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// Assign <id>s to all blocks so that branches inside the LoopOp can resolve
// properly. We don't need to assign for the entry block, which is just for
// satisfying MLIR region's structural requirement.
- auto &body = loopOp.body();
+ auto &body = loopOp.getBody();
for (Block &block : llvm::drop_begin(body))
getOrCreateBlockID(&block);
@@ -452,7 +453,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
lastProcessedWasMergeInst = true;
encodeInstructionInto(
functionBody, spirv::Opcode::OpLoopMerge,
- {mergeID, continueID, static_cast<uint32_t>(loopOp.loop_control())});
+ {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
return success();
};
if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
@@ -483,12 +484,12 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
LogicalResult Serializer::processBranchConditionalOp(
spirv::BranchConditionalOp condBranchOp) {
- auto conditionID = getValueID(condBranchOp.condition());
+ auto conditionID = getValueID(condBranchOp.getCondition());
auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
- if (auto weights = condBranchOp.branch_weights()) {
+ if (auto weights = condBranchOp.getBranchWeights()) {
for (auto val : weights->getValue())
arguments.push_back(val.cast<IntegerAttr>().getInt());
}
@@ -509,26 +510,26 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
}
LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
- auto varName = addressOfOp.variable();
+ auto varName = addressOfOp.getVariable();
auto variableID = getVariableID(varName);
if (!variableID) {
return addressOfOp.emitError("unknown result <id> for variable ")
<< varName;
}
- valueIDMap[addressOfOp.pointer()] = variableID;
+ valueIDMap[addressOfOp.getPointer()] = variableID;
return success();
}
LogicalResult
Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
- auto constName = referenceOfOp.spec_const();
+ auto constName = referenceOfOp.getSpecConst();
auto constID = getSpecConstID(constName);
if (!constID) {
return referenceOfOp.emitError(
"unknown result <id> for specialization constant ")
<< constName;
}
- valueIDMap[referenceOfOp.reference()] = constID;
+ valueIDMap[referenceOfOp.getReference()] = constID;
return success();
}
@@ -537,21 +538,21 @@ LogicalResult
Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
SmallVector<uint32_t, 4> operands;
// Add the ExecutionModel.
- operands.push_back(static_cast<uint32_t>(op.execution_model()));
+ operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
// Add the function <id>.
- auto funcID = getFunctionID(op.fn());
+ auto funcID = getFunctionID(op.getFn());
if (!funcID) {
return op.emitError("missing <id> for function ")
- << op.fn()
+ << op.getFn()
<< "; function needs to be defined before spv.EntryPoint is "
"serialized";
}
operands.push_back(funcID);
// Add the name of the function.
- spirv::encodeStringLiteralInto(operands, op.fn());
+ spirv::encodeStringLiteralInto(operands, op.getFn());
// Add the interface values.
- if (auto interface = op.interface()) {
+ if (auto interface = op.getInterface()) {
for (auto var : interface.getValue()) {
auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
if (!id) {
@@ -571,19 +572,19 @@ LogicalResult
Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
SmallVector<uint32_t, 4> operands;
// Add the function <id>.
- auto funcID = getFunctionID(op.fn());
+ auto funcID = getFunctionID(op.getFn());
if (!funcID) {
return op.emitError("missing <id> for function ")
- << op.fn()
+ << op.getFn()
<< "; function needs to be serialized before ExecutionModeOp is "
"serialized";
}
operands.push_back(funcID);
// Add the ExecutionMode.
- operands.push_back(static_cast<uint32_t>(op.execution_mode()));
+ operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
// Serialize values if any.
- auto values = op.values();
+ auto values = op.getValues();
if (values) {
for (auto &intVal : values.getValue()) {
operands.push_back(static_cast<uint32_t>(
@@ -598,7 +599,7 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
template <>
LogicalResult
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
- auto funcName = op.callee();
+ auto funcName = op.getCallee();
uint32_t resTypeID = 0;
Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
@@ -609,7 +610,7 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
auto funcCallID = getNextID();
SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
- for (auto value : op.arguments()) {
+ for (auto value : op.getArguments()) {
auto valueID = getValueID(value);
assert(valueID && "cannot find a value for spv.FunctionCall");
operands.push_back(valueID);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 7c0b9f33f9e77..e482b257ddf43 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -119,7 +119,8 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
binary.clear();
binary.reserve(moduleSize);
- spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID);
+ spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
+ nextID);
binary.append(capabilities.begin(), capabilities.end());
binary.append(extensions.begin(), extensions.end());
binary.append(extendedSets.begin(), extendedSets.end());
@@ -166,7 +167,7 @@ uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
}
void Serializer::processCapability() {
- for (auto cap : module.vce_triple()->getCapabilities())
+ for (auto cap : module.getVceTriple()->getCapabilities())
encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
{static_cast<uint32_t>(cap)});
}
@@ -186,7 +187,7 @@ void Serializer::processDebugInfo() {
void Serializer::processExtension() {
llvm::SmallVector<uint32_t, 16> extName;
- for (spirv::Extension ext : module.vce_triple()->getExtensions()) {
+ for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
extName.clear();
spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
@@ -1045,11 +1046,11 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
} else if (auto branchCondOp =
dyn_cast<spirv::BranchConditionalOp>(terminator)) {
Optional<OperandRange> blockOperands;
- if (branchCondOp.trueTarget() == block) {
- blockOperands = branchCondOp.trueTargetOperands();
+ if (branchCondOp.getTrueTarget() == block) {
+ blockOperands = branchCondOp.getTrueTargetOperands();
} else {
- assert(branchCondOp.falseTarget() == block);
- blockOperands = branchCondOp.falseTargetOperands();
+ assert(branchCondOp.getFalseTarget() == block);
+ blockOperands = branchCondOp.getFalseTargetOperands();
}
assert(!blockOperands->empty() &&
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 8329cfe18dc0a..f02ba1b09bc0c 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -1360,7 +1360,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
"static_cast<{0}::{1}>(1 << i);\n",
enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
- namedAttr.name);
+ srcOp.getGetterName(namedAttr.name));
os << formatv(
" if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
enumAttr.getUnderlyingType());
@@ -1368,7 +1368,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
// For IntEnumAttr, we just need to query the value as a whole.
os << " {\n";
os << formatv(" auto tblgen_attrVal = this->{0}();\n",
- namedAttr.name);
+ srcOp.getGetterName(namedAttr.name));
}
os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
enumAttr.getCppNamespace(), avail.getQueryFnName());
More information about the Mlir-commits
mailing list