[Mlir-commits] [mlir] [mlir][Affine] Let affine.[de]linearize_index omit outer bounds (PR #116103)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 13 20:28:06 PST 2024


================
@@ -4503,62 +4504,81 @@ LogicalResult AffineVectorStoreOp::verify() {
 // DelinearizeIndexOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
-    MLIRContext *context, std::optional<::mlir::Location> location,
-    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
-    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
-  AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
-                                          regions);
-  inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
-                             IndexType::get(context));
-  return success();
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
+                                     Value linearIndex, ValueRange dynamicBasis,
+                                     ArrayRef<int64_t> staticBasis,
+                                     bool hasOuterBound) {
+  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
+                                              : staticBasis.size() + 1,
+                                linearIndex.getType());
+  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
+        staticBasis);
 }
 
 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
-                                     Value linearIndex, ValueRange basis) {
+                                     Value linearIndex, ValueRange basis,
+                                     bool hasOuterBound) {
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
                              staticBasis);
-  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
+        hasOuterBound);
 }
 
 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
                                      Value linearIndex,
-                                     ArrayRef<OpFoldResult> basis) {
+                                     ArrayRef<OpFoldResult> basis,
+                                     bool hasOuterBound) {
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
-  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
+        hasOuterBound);
 }
 
 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
-                                     Value linearIndex,
-                                     ArrayRef<int64_t> basis) {
-  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
+                                     Value linearIndex, ArrayRef<int64_t> basis,
+                                     bool hasOuterBound) {
+  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
 }
 
 LogicalResult AffineDelinearizeIndexOp::verify() {
-  if (getStaticBasis().empty())
-    return emitOpError("basis should not be empty");
-  if (getNumResults() != getStaticBasis().size())
-    return emitOpError("should return an index for each basis element");
-  auto dynamicMarkersCount =
-      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+  ArrayRef<int64_t> staticBasis = getStaticBasis();
+  if (getNumResults() != staticBasis.size() &&
+      getNumResults() != staticBasis.size() + 1)
+    return emitOpError("should return an index for each basis element and up "
----------------
MaheshRavishankar wrote:

Nit: for multi-line statements it's clearer to just add `{` `}`

https://github.com/llvm/llvm-project/pull/116103


More information about the Mlir-commits mailing list