[Mlir-commits] [mlir] [MLIR][XeGPU] make offsets optional for create_nd_tdesc (PR #148335)

Jianhui Li llvmlistbot at llvm.org
Tue Jul 15 12:43:33 PDT 2025


================
@@ -221,6 +221,246 @@ LogicalResult CreateNdDescOp::verify() {
   return success();
 }
 
+/// Parses an optional, square-bracketed list of SSA values and/or integer
+/// literals, e.g. `[%a, 4, %b]`.
+///
+/// If the opening `[` is absent, nothing is consumed and success is returned
+/// with `integers` and `scalableFlags` left untouched (callers treat a null
+/// `integers` as "list omitted"). When the list is present, each SSA operand
+/// is appended to `values` and recorded in `integers` as
+/// `ShapedType::kDynamic`; each integer literal is stored in `integers`
+/// directly. An element followed by its own `[...]` wrapper is recorded as
+/// `true` in `scalableFlags`.
+///
+/// \param valueTypes if non-null, every SSA operand must be followed by
+///        `: type`, and the parsed types are appended here.
+/// \param delimiter only `AsmParser::Delimiter::Square` is supported; the
+///        parameter mirrors the signature of `parseDynamicIndexList`.
+static ParseResult parseOptionalDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
+    SmallVectorImpl<Type> *valueTypes = nullptr,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+  // Presence detection below is hard-wired to `[`/`]`; make a call with any
+  // other delimiter fail loudly instead of being silently ignored.
+  assert(delimiter == AsmParser::Delimiter::Square &&
+         "parseOptionalDynamicIndexList only supports square delimiters");
+  (void)delimiter;
+
+  // The list is optional: without a leading `[` there is nothing to parse.
+  if (failed(parser.parseOptionalLSquare()))
+    return success();
+
+  SmallVector<int64_t, 4> integerVals;
+  SmallVector<bool, 4> scalableVals;
+  auto parseIntegerOrValue = [&]() -> ParseResult {
+    OpAsmParser::UnresolvedOperand operand;
+    auto res = parser.parseOptionalOperand(operand);
+
+    // When encountering `[`, assume that this is a scalable index.
+    scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
+
+    if (res.has_value()) {
+      // An operand token was present but malformed. The error has already
+      // been reported, so do not retry the token as an integer literal.
+      if (failed(res.value()))
+        return failure();
+      values.push_back(operand);
+      integerVals.push_back(ShapedType::kDynamic);
+      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+        return failure();
+    } else {
+      int64_t integer;
+      if (failed(parser.parseInteger(integer)))
+        return failure();
+      integerVals.push_back(integer);
+    }
+
+    // If this is assumed to be a scalable index, verify that there's a closing
+    // `]`.
+    if (scalableVals.back() && parser.parseOptionalRSquare().failed())
+      return failure();
+    return success();
+  };
+
+  if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
+      parser.parseRSquare())
+    return parser.emitError(parser.getNameLoc())
+           << "expected SSA value or integer";
+  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
+  scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
+  return success();
+}
+
+/// Parses the custom assembly form of `xegpu.create_nd_tdesc`:
+///
+///   %source (`[` offsets `]`)?
+///           (`shape` `:` `[` sizes `]` `strides` `:` `[` strides `]`)?
+///           attr-dict `:` source-type `->` tensor-desc-type
+///
+/// Offsets are optional; when omitted, `const_offsets` is left unset and the
+/// offsets operand segment is empty. Once `shape` is given, `strides` must
+/// follow. All offset/shape/stride SSA operands are resolved as `index`.
+::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser,
+                                          ::mlir::OperationState &result) {
+  ::mlir::OpAsmParser::UnresolvedOperand sourceRawOperand{};
+  ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(
+      &sourceRawOperand, 1);
+  ::llvm::SMLoc sourceOperandsLoc;
+
+  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+      offsetsOperands;
+  ::llvm::SMLoc offsetsOperandsLoc;
+  ::mlir::DenseI64ArrayAttr const_offsetsAttr;
+  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> shapeOperands;
+  ::llvm::SMLoc shapeOperandsLoc;
+  ::mlir::DenseI64ArrayAttr const_shapeAttr;
+  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+      stridesOperands;
+  ::llvm::SMLoc stridesOperandsLoc;
+  ::mlir::DenseI64ArrayAttr const_stridesAttr;
+  ::mlir::Type sourceRawType{};
+  ::llvm::ArrayRef<::mlir::Type> sourceTypes(&sourceRawType, 1);
+  ::mlir::Type TensorDescRawType{};
+  ::llvm::ArrayRef<::mlir::Type> TensorDescTypes(&TensorDescRawType, 1);
+
+  sourceOperandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperand(sourceRawOperand))
+    return ::mlir::failure();
+
+  // Optional offsets, e.g. `[%x, 0]`. The parse result is checked
+  // unconditionally: on failure `const_offsetsAttr` stays null, so gating
+  // the check on the attribute would silently swallow the error.
+  offsetsOperandsLoc = parser.getCurrentLocation();
+  ::mlir::DenseBoolArrayAttr offsetsScalableFlags;
+  if (::mlir::failed(parseOptionalDynamicIndexList(
+          parser, offsetsOperands, const_offsetsAttr, offsetsScalableFlags)))
+    return ::mlir::failure();
+  if (const_offsetsAttr) {
+    // This op stores no scalable flags for its offsets, so reject `[[N]]`
+    // elements instead of silently dropping the scalable marker.
+    if (offsetsScalableFlags &&
+        ::llvm::is_contained(offsetsScalableFlags.asArrayRef(), true))
+      return parser.emitError(offsetsOperandsLoc)
+             << "scalable offsets are not supported";
+    result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets =
+        const_offsetsAttr;
+  }
+
+  // Optional `shape : [...] strides : [...]` group.
+  if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
+    if (parser.parseColon())
+      return ::mlir::failure();
+    shapeOperandsLoc = parser.getCurrentLocation();
+    if (parseDynamicIndexList(parser, shapeOperands, const_shapeAttr))
+      return ::mlir::failure();
+    if (const_shapeAttr)
+      result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape =
+          const_shapeAttr;
+
+    if (parser.parseKeyword("strides"))
+      return ::mlir::failure();
+    if (parser.parseColon())
+      return ::mlir::failure();
+    stridesOperandsLoc = parser.getCurrentLocation();
+    if (parseDynamicIndexList(parser, stridesOperands, const_stridesAttr))
+      return ::mlir::failure();
+    if (const_stridesAttr)
+      result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides =
+          const_stridesAttr;
+  }
+
+  {
+    auto loc = parser.getCurrentLocation();
+    if (parser.parseOptionalAttrDict(result.attributes))
+      return ::mlir::failure();
+    if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+          return parser.emitError(loc)
+                 << "'" << result.name.getStringRef() << "' op ";
+        })))
+      return ::mlir::failure();
+  }
+  if (parser.parseColon())
+    return ::mlir::failure();
+
+  {
+    ::mlir::Type type;
+    if (parser.parseCustomTypeWithFallback(type))
+      return ::mlir::failure();
+    sourceRawType = type;
+  }
+  if (parser.parseArrow())
+    return ::mlir::failure();
+
+  if (parser.parseType(TensorDescRawType))
+    return ::mlir::failure();
+
+  // Segment order: source (always 1), offsets, shape, strides.
+  ::llvm::copy(::llvm::ArrayRef<int32_t>(
+                   {1, static_cast<int32_t>(offsetsOperands.size()),
+                    static_cast<int32_t>(shapeOperands.size()),
+                    static_cast<int32_t>(stridesOperands.size())}),
+               result.getOrAddProperties<CreateNdDescOp::Properties>()
+                   .operandSegmentSizes.begin());
+
+  ::mlir::Type indexType = parser.getBuilder().getIndexType();
+  result.addTypes(TensorDescTypes);
+
+  if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc,
+                             result.operands))
+    return ::mlir::failure();
+
+  if (parser.resolveOperands(offsetsOperands, indexType, offsetsOperandsLoc,
+                             result.operands))
+    return ::mlir::failure();
+
+  if (parser.resolveOperands(shapeOperands, indexType, shapeOperandsLoc,
+                             result.operands))
+    return ::mlir::failure();
+
+  if (parser.resolveOperands(stridesOperands, indexType, stridesOperandsLoc,
+                             result.operands))
+    return ::mlir::failure();
+  return ::mlir::success();
+}
+
+void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
+  _odsPrinter << ' ';
+  _odsPrinter << getSource();
----------------
Jianhui-Li wrote:

removed. 

https://github.com/llvm/llvm-project/pull/148335


More information about the Mlir-commits mailing list