[Mlir-commits] [mlir] [MLIR][XeGPU] make offsets optional for create_nd_tdesc (PR #148335)
Jianhui Li
llvmlistbot at llvm.org
Tue Jul 15 12:43:33 PDT 2025
================
@@ -221,6 +221,246 @@ LogicalResult CreateNdDescOp::verify() {
return success();
}
+ParseResult parseOptionalDynamicIndexList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
+ SmallVectorImpl<Type> *valueTypes = nullptr,
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+
+ SmallVector<int64_t, 4> integerVals;
+ SmallVector<bool, 4> scalableVals;
+ auto parseIntegerOrValue = [&]() {
+ OpAsmParser::UnresolvedOperand operand;
+ auto res = parser.parseOptionalOperand(operand);
+
+ // When encountering `[`, assume that this is a scalable index.
+ scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
+
+ if (res.has_value() && succeeded(res.value())) {
+ values.push_back(operand);
+ integerVals.push_back(ShapedType::kDynamic);
+ if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+ return failure();
+ } else {
+ int64_t integer;
+ if (failed(parser.parseInteger(integer)))
+ return failure();
+ integerVals.push_back(integer);
+ }
+
+ // If this is assumed to be a scalable index, verify that there's a closing
+ // `]`.
+ if (scalableVals.back() && parser.parseOptionalRSquare().failed())
+ return failure();
+ return success();
+ };
+ if (parser.parseOptionalLSquare().succeeded()) {
+ if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
+ parser.parseRSquare())
+ return parser.emitError(parser.getNameLoc())
+ << "expected SSA value or integer";
+ integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
+ scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
+ return success();
+ }
+ return success();
+}
+
+::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
+ ::mlir::OpAsmParser::UnresolvedOperand sourceRawOperand{};
+ ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(
+ &sourceRawOperand, 1);
+ ::llvm::SMLoc sourceOperandsLoc;
+
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+ offsetsOperands;
+ ::llvm::SMLoc offsetsOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_offsetsAttr;
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> shapeOperands;
+ ::llvm::SMLoc shapeOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_shapeAttr;
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+ stridesOperands;
+ ::llvm::SMLoc stridesOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_stridesAttr;
+ ::mlir::Type sourceRawType{};
+ ::llvm::ArrayRef<::mlir::Type> sourceTypes(&sourceRawType, 1);
+ ::mlir::Type TensorDescRawType{};
+ ::llvm::ArrayRef<::mlir::Type> TensorDescTypes(&TensorDescRawType, 1);
+
+ sourceOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperand(sourceRawOperand))
+ return ::mlir::failure();
+
+ offsetsOperandsLoc = parser.getCurrentLocation();
+
+ DenseBoolArrayAttr scalableFlags;
+ auto odsResult = parseOptionalDynamicIndexList(
+ parser, offsetsOperands, const_offsetsAttr, scalableFlags);
+
+ if (const_offsetsAttr) {
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets =
+ const_offsetsAttr;
+ }
+
+ if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
+ if (parser.parseColon())
+ return ::mlir::failure();
+ {
+ shapeOperandsLoc = parser.getCurrentLocation();
+ auto odsResult =
+ parseDynamicIndexList(parser, shapeOperands, const_shapeAttr);
+ if (const_shapeAttr) {
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape =
+ const_shapeAttr;
+ }
+ }
+
+ if (parser.parseKeyword("strides"))
+ return ::mlir::failure();
+ if (parser.parseColon())
+ return ::mlir::failure();
+ {
+ stridesOperandsLoc = parser.getCurrentLocation();
+ auto odsResult =
+ parseDynamicIndexList(parser, stridesOperands, const_stridesAttr);
+ if (const_stridesAttr) {
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides =
+ const_stridesAttr;
+ }
+ }
+ }
+ {
+ auto loc = parser.getCurrentLocation();
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return ::mlir::failure();
+ if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+ return parser.emitError(loc)
+ << "'" << result.name.getStringRef() << "' op ";
+ })))
+ return ::mlir::failure();
+ }
+ if (parser.parseColon())
+ return ::mlir::failure();
+
+ {
+ ::mlir::Type type;
+ if (parser.parseCustomTypeWithFallback(type))
+ return ::mlir::failure();
+ sourceRawType = type;
+ }
+ if (parser.parseArrow())
+ return ::mlir::failure();
+
+ if (parser.parseType(TensorDescRawType))
+ return ::mlir::failure();
+
+ ::llvm::copy(::llvm::ArrayRef<int32_t>(
+ {1, static_cast<int32_t>(offsetsOperands.size()),
+ static_cast<int32_t>(shapeOperands.size()),
+ static_cast<int32_t>(stridesOperands.size())}),
+ result.getOrAddProperties<CreateNdDescOp::Properties>()
+ .operandSegmentSizes.begin());
+
+ ::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType();
+ result.addTypes(TensorDescTypes);
+
+ if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc,
+ result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(offsetsOperands, odsBuildableType0,
+ offsetsOperandsLoc, result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(shapeOperands, odsBuildableType0, shapeOperandsLoc,
+ result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(stridesOperands, odsBuildableType0,
+ stridesOperandsLoc, result.operands))
+ return ::mlir::failure();
+ return ::mlir::success();
+}
+
+void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
+ _odsPrinter << ' ';
+ _odsPrinter << getSource();
----------------
Jianhui-Li wrote:
removed.
https://github.com/llvm/llvm-project/pull/148335
More information about the Mlir-commits
mailing list