[Mlir-commits] [mlir] [mlir][affine] Add static basis support to affine.delinearize (PR #113846)

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Nov 4 10:34:13 PST 2024


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/113846

>From 96df4429d9a56512663b643b42c8a695fd5135d0 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Sun, 27 Oct 2024 23:31:35 +0000
Subject: [PATCH 1/4] [mlir][affine] Add static basis support to
 affine.delinearize

This commit makes `affine.delinearize_index` join other indexing operators,
like `vector.extract`, which store a mixed static/dynamic set of
sizes, offsets, or the like. In this case, the `basis` (the set of values
that will be used to decompose the linear index) is now stored as a
static integer array attribute wherever the basis is statically known,
with only the dynamic entries passed as operands, eliminating the need
to create constants.

This commit also adds overloads of the delinearize utilities in the affine
dialect that take an array of `OpFoldResult`s, and extends the
DynamicIndexList parser/printer to allow specifying the delimiters in
tablegen. This is needed to avoid breaking existing syntax: the op's
basis has always been written in parentheses, while DynamicIndexList
defaults to square brackets.
---
 .../mlir/Dialect/Affine/IR/AffineOps.h        |  2 +-
 .../mlir/Dialect/Affine/IR/AffineOps.td       | 21 ++++-
 mlir/include/mlir/Dialect/Affine/Utils.h      |  3 +
 .../mlir/Interfaces/ViewLikeInterface.h       | 16 ++++
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 64 +++++++++-----
 .../Transforms/AffineExpandIndexOps.cpp       |  5 +-
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       | 38 +++++++++
 .../AffineToStandard/lower-affine.mlir        | 83 +++++++++----------
 .../Affine/affine-expand-index-ops.mlir       |  5 +-
 mlir/test/Dialect/Affine/canonicalize.mlir    | 11 +--
 mlir/test/Dialect/Affine/loop-coalescing.mlir | 27 ++----
 mlir/test/Dialect/Affine/ops.mlir             |  7 ++
 .../extract-slice-from-collapse-shape.mlir    | 30 ++-----
 .../Vector/vector-warp-distribute.mlir        |  2 +-
 mlir/test/python/dialects/affine.py           |  3 +-
 15 files changed, 186 insertions(+), 131 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 5c75e102c3d404..7c950623f77f48 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -16,11 +16,11 @@
 
 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
-
 namespace mlir {
 namespace affine {
 
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 8773fc5881461a..f53b5d97a7156a 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1084,17 +1084,32 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
     ```
   }];
 
-  let arguments = (ins Index:$linear_index, Variadic<Index>:$basis);
+  let arguments = (ins Index:$linear_index,
+    Variadic<Index>:$dynamic_basis,
+    DenseI64ArrayAttr:$static_basis);
   let results = (outs Variadic<Index>:$multi_index);
 
   let assemblyFormat = [{
-    $linear_index `into` ` ` `(` $basis `)` attr-dict `:` type($multi_index)
+    $linear_index `into` ` `
+    custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
+    attr-dict `:` type($multi_index)
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>
+    OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis)>,
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>,
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis)>
   ];
 
+  let extraClassDeclaration = [{
+    /// Return a vector with all the static and dynamic basis values.
+    SmallVector<OpFoldResult> getMixedBasis() {
+      OpBuilder builder(getContext());
+      return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+    }
+
+  }];
+
   let hasVerifier = 1;
   let hasCanonicalizer = 1;
 }
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 9a2767e0ad87f3..d2cfbaa85a60ef 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -311,6 +311,9 @@ DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
                                                ArrayRef<Value> basis);
+FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
+                                               Value linearIndex,
+                                               ArrayRef<OpFoldResult> basis);
 // Generate IR that extracts the linear index from a multi-index according to
 // a basis/shape.
 OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index d6479143a0a50b..3dcbd2f1af1936 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -109,6 +109,13 @@ void printDynamicIndexList(
     ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
     TypeRange valueTypes = TypeRange(),
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+                                  OperandRange values,
+                                  ArrayRef<int64_t> integers,
+                                  AsmParser::Delimiter delimiter) {
+  return printDynamicIndexList(printer, op, values, integers, {}, TypeRange(),
+                               delimiter);
+}
 inline void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
@@ -144,6 +151,15 @@ ParseResult parseDynamicIndexList(
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
     SmallVectorImpl<Type> *valueTypes = nullptr,
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline ParseResult
+parseDynamicIndexList(OpAsmParser &parser,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      DenseI64ArrayAttr &integers,
+                      AsmParser::Delimiter delimiter) {
+  DenseBoolArrayAttr scalableVals = {};
+  return parseDynamicIndexList(parser, values, integers, scalableVals, nullptr,
+                               delimiter);
+}
 inline ParseResult parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5e7a6b6ca883c3..f384f454bc4726 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
@@ -4508,32 +4509,50 @@ LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
     RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
   AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
                                           regions);
-  inferredReturnTypes.assign(adaptor.getBasis().size(),
+  inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
                              IndexType::get(context));
   return success();
 }
 
-void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
+                                     Value linearIndex, ValueRange basis) {
+  SmallVector<Value> dynamicBasis;
+  SmallVector<int64_t> staticBasis;
+  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
+                             staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+}
+
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
                                      Value linearIndex,
                                      ArrayRef<OpFoldResult> basis) {
-  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
-  result.addOperands(linearIndex);
-  SmallVector<Value> basisValues =
-      llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value {
-        std::optional<int64_t> staticDim = getConstantIntValue(ofr);
-        if (staticDim.has_value())
-          return builder.create<arith::ConstantIndexOp>(result.location,
-                                                        *staticDim);
-        return llvm::dyn_cast_if_present<Value>(ofr);
-      });
-  result.addOperands(basisValues);
+  SmallVector<Value> dynamicBasis;
+  SmallVector<int64_t> staticBasis;
+  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+}
+
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
+                                     Value linearIndex,
+                                     ArrayRef<int64_t> basis) {
+  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
 }
 
 LogicalResult AffineDelinearizeIndexOp::verify() {
-  if (getBasis().empty())
+  if (getStaticBasis().empty())
     return emitOpError("basis should not be empty");
-  if (getNumResults() != getBasis().size())
+  if (getNumResults() != getStaticBasis().size())
     return emitOpError("should return an index for each basis element");
+  auto dynamicMarkersCount =
+      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
+    return emitOpError(
+        "mismatch between dynamic and static basis (kDynamic marker but no "
+        "corresponding dynamic basis entry) -- this can only happen due to an "
+        "incorrect fold/rewrite");
   return success();
 }
 
@@ -4557,15 +4576,16 @@ struct DropUnitExtentBasis
 
     // Replace all indices corresponding to unit-extent basis with 0.
     // Remaining basis can be used to get a new `affine.delinearize_index` op.
-    SmallVector<Value> newOperands;
-    for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) {
-      if (matchPattern(basis, m_One()))
+    SmallVector<OpFoldResult> newOperands;
+    for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
+      std::optional<int64_t> basisVal = getConstantIntValue(basis);
+      if (basisVal && *basisVal == 1)
         replacements[index] = getZero();
       else
         newOperands.push_back(basis);
     }
 
-    if (newOperands.size() == delinearizeOp.getBasis().size())
+    if (newOperands.size() == delinearizeOp.getStaticBasis().size())
       return failure();
 
     if (!newOperands.empty()) {
@@ -4607,9 +4627,9 @@ struct DropDelinearizeOfSingleLoop
 
   LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
                                 PatternRewriter &rewriter) const override {
-    auto basis = delinearizeOp.getBasis();
-    if (basis.size() != 1)
+    if (delinearizeOp.getStaticBasis().size() != 1)
       return failure();
+    auto basis = delinearizeOp.getMixedBasis();
 
     // Check that the `linear_index` is an induction variable.
     auto inductionVar = dyn_cast<BlockArgument>(delinearizeOp.getLinearIndex());
@@ -4634,7 +4654,7 @@ struct DropDelinearizeOfSingleLoop
     // Check that the upper-bound is the basis.
     auto upperBounds = loopLikeOp.getLoopUpperBounds();
     if (!upperBounds || upperBounds->size() != 1 ||
-        upperBounds->front() != getAsOpFoldResult(basis.front())) {
+        upperBounds->front() != basis.front()) {
       return rewriter.notifyMatchFailure(delinearizeOp,
                                          "`basis` is not upper bound");
     }
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index c6bc3862256a75..d76968d3a71520 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -35,9 +35,8 @@ struct LowerDelinearizeIndexOps
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<SmallVector<Value>> multiIndex =
-        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
-                         llvm::to_vector(op.getBasis()));
+    FailureOr<SmallVector<Value>> multiIndex = delinearizeIndex(
+        rewriter, op->getLoc(), op.getLinearIndex(), op.getMixedBasis());
     if (failed(multiIndex))
       return failure();
     rewriter.replaceOp(op, *multiIndex);
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 910ad1733d03e8..e3b5d26e0ec3c3 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1944,6 +1944,18 @@ static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
   return result;
 }
 
+static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
+                                               ArrayRef<OpFoldResult> set) {
+  if (set.empty())
+    return failure();
+  OpFoldResult result = set[0];
+  AffineExpr s0, s1;
+  bindSymbols(b.getContext(), s0, s1);
+  for (unsigned i = 1, e = set.size(); i < e; i++)
+    result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]});
+  return result;
+}
+
 FailureOr<SmallVector<Value>>
 mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
                                ArrayRef<Value> basis) {
@@ -1970,6 +1982,32 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
   return results;
 }
 
+FailureOr<SmallVector<Value>>
+mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
+                               ArrayRef<OpFoldResult> basis) {
+  unsigned numDims = basis.size();
+
+  SmallVector<Value> divisors;
+  for (unsigned i = 1; i < numDims; i++) {
+    ArrayRef<OpFoldResult> slice = basis.drop_front(i);
+    FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice);
+    if (failed(prod))
+      return failure();
+    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
+  }
+
+  SmallVector<Value> results;
+  results.reserve(divisors.size() + 1);
+  Value residual = linearIndex;
+  for (Value divisor : divisors) {
+    DivModValue divMod = getDivMod(b, loc, residual, divisor);
+    results.push_back(divMod.quotient);
+    residual = divMod.remainder;
+  }
+  results.push_back(residual);
+  return results;
+}
+
 OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                                           ArrayRef<OpFoldResult> basis,
                                           ImplicitLocOpBuilder &builder) {
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 23e0edd510cbb1..298e82df4f4cea 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -931,53 +931,48 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
 ///////////////////////////////////////////////////////////////////////
 
 func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) {
-  %b0 = arith.constant 16 : index
-  %b1 = arith.constant 224 : index
-  %b2 = arith.constant 224 : index
-  %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
+  %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
   return %1#0, %1#1, %1#2 : index, index, index
 }
 // CHECK-LABEL:   func.func @test_dilinearize_index(
 // CHECK-SAME:                                      %[[VAL_0:.*]]: index) -> (index, index, index) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 50176 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_5:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_7:.*]] = arith.constant -1 : index
-// CHECK:           %[[VAL_8:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_6]] : index
-// CHECK:           %[[VAL_9:.*]] = arith.subi %[[VAL_7]], %[[VAL_0]] : index
-// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_8]], %[[VAL_9]], %[[VAL_0]] : index
-// CHECK:           %[[VAL_11:.*]] = arith.divsi %[[VAL_10]], %[[VAL_5]] : index
-// CHECK:           %[[VAL_12:.*]] = arith.subi %[[VAL_7]], %[[VAL_11]] : index
-// CHECK:           %[[VAL_13:.*]] = arith.select %[[VAL_8]], %[[VAL_12]], %[[VAL_11]] : index
-// CHECK:           %[[VAL_14:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_15:.*]] = arith.remsi %[[VAL_0]], %[[VAL_14]] : index
-// CHECK:           %[[VAL_16:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_17:.*]] = arith.cmpi slt, %[[VAL_15]], %[[VAL_16]] : index
-// CHECK:           %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_14]] : index
-// CHECK:           %[[VAL_19:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_15]] : index
-// CHECK:           %[[VAL_20:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_21:.*]] = arith.remsi %[[VAL_0]], %[[VAL_20]] : index
-// CHECK:           %[[VAL_22:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_23:.*]] = arith.cmpi slt, %[[VAL_21]], %[[VAL_22]] : index
-// CHECK:           %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_20]] : index
-// CHECK:           %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_21]] : index
-// CHECK:           %[[VAL_26:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_27:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_28:.*]] = arith.constant -1 : index
-// CHECK:           %[[VAL_29:.*]] = arith.cmpi slt, %[[VAL_25]], %[[VAL_27]] : index
-// CHECK:           %[[VAL_30:.*]] = arith.subi %[[VAL_28]], %[[VAL_25]] : index
-// CHECK:           %[[VAL_31:.*]] = arith.select %[[VAL_29]], %[[VAL_30]], %[[VAL_25]] : index
-// CHECK:           %[[VAL_32:.*]] = arith.divsi %[[VAL_31]], %[[VAL_26]] : index
-// CHECK:           %[[VAL_33:.*]] = arith.subi %[[VAL_28]], %[[VAL_32]] : index
-// CHECK:           %[[VAL_34:.*]] = arith.select %[[VAL_29]], %[[VAL_33]], %[[VAL_32]] : index
-// CHECK:           %[[VAL_35:.*]] = arith.constant 224 : index
-// CHECK:           %[[VAL_36:.*]] = arith.remsi %[[VAL_0]], %[[VAL_35]] : index
-// CHECK:           %[[VAL_37:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_38:.*]] = arith.cmpi slt, %[[VAL_36]], %[[VAL_37]] : index
-// CHECK:           %[[VAL_39:.*]] = arith.addi %[[VAL_36]], %[[VAL_35]] : index
-// CHECK:           %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_36]] : index
-// CHECK:           return %[[VAL_13]], %[[VAL_34]], %[[VAL_40]] : index, index, index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 50176 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant -1 : index
+// CHECK:           %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_4]] : index
+// CHECK:           %[[VAL_7:.*]] = arith.subi %[[VAL_5]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_8:.*]] = arith.select %[[VAL_6]], %[[VAL_7]], %[[VAL_0]] : index
+// CHECK:           %[[VAL_9:.*]] = arith.divsi %[[VAL_8]], %[[VAL_3]] : index
+// CHECK:           %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_9]] : index
+// CHECK:           %[[VAL_11:.*]] = arith.select %[[VAL_6]], %[[VAL_10]], %[[VAL_9]] : index
+// CHECK:           %[[VAL_12:.*]] = arith.constant 50176 : index
+// CHECK:           %[[VAL_13:.*]] = arith.remsi %[[VAL_0]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_14]] : index
+// CHECK:           %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_17:.*]] = arith.select %[[VAL_15]], %[[VAL_16]], %[[VAL_13]] : index
+// CHECK:           %[[VAL_18:.*]] = arith.constant 50176 : index
+// CHECK:           %[[VAL_19:.*]] = arith.remsi %[[VAL_0]], %[[VAL_18]] : index
+// CHECK:           %[[VAL_20:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_21:.*]] = arith.cmpi slt, %[[VAL_19]], %[[VAL_20]] : index
+// CHECK:           %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index
+// CHECK:           %[[VAL_23:.*]] = arith.select %[[VAL_21]], %[[VAL_22]], %[[VAL_19]] : index
+// CHECK:           %[[VAL_24:.*]] = arith.constant 224 : index
+// CHECK:           %[[VAL_25:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_26:.*]] = arith.constant -1 : index
+// CHECK:           %[[VAL_27:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_25]] : index
+// CHECK:           %[[VAL_28:.*]] = arith.subi %[[VAL_26]], %[[VAL_23]] : index
+// CHECK:           %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_23]] : index
+// CHECK:           %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_24]] : index
+// CHECK:           %[[VAL_31:.*]] = arith.subi %[[VAL_26]], %[[VAL_30]] : index
+// CHECK:           %[[VAL_32:.*]] = arith.select %[[VAL_27]], %[[VAL_31]], %[[VAL_30]] : index
+// CHECK:           %[[VAL_33:.*]] = arith.constant 224 : index
+// CHECK:           %[[VAL_34:.*]] = arith.remsi %[[VAL_0]], %[[VAL_33]] : index
+// CHECK:           %[[VAL_35:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_36:.*]] = arith.cmpi slt, %[[VAL_34]], %[[VAL_35]] : index
+// CHECK:           %[[VAL_37:.*]] = arith.addi %[[VAL_34]], %[[VAL_33]] : index
+// CHECK:           %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_37]], %[[VAL_34]] : index
+// CHECK:           return %[[VAL_11]], %[[VAL_32]], %[[VAL_38]] : index, index, index
 // CHECK:         }
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
index 70b7f397ad4fec..95773206a521e6 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
@@ -11,10 +11,7 @@
 //       CHECK:   %[[Q:.+]] = affine.apply #[[$map2]]()[%[[IDX]]]
 //       CHECK:   return %[[N]], %[[P]], %[[Q]]
 func.func @static_basis(%linear_index: index) -> (index, index, index) {
-  %b0 = arith.constant 16 : index
-  %b1 = arith.constant 224 : index
-  %b2 = arith.constant 224 : index
-  %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
+  %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index
   return %1#0, %1#1, %1#2 : index, index, index
 }
 
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 906ae81c76d115..d78c3b667589b8 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1472,7 +1472,7 @@ func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
 func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 : index) ->
     (index, index, index, index, index, index) {
   %c1 = arith.constant 1 : index
-  %0:6 = affine.delinearize_index %arg0 into (%c1, %arg1, %c1, %c1, %arg2, %c1)
+  %0:6 = affine.delinearize_index %arg0 into (1, %arg1, 1, 1, %arg2, %c1)
       : index, index, index, index, index, index
   return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : index, index, index, index, index, index
 }
@@ -1487,8 +1487,7 @@ func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 :
 // -----
 
 func.func @drop_all_unit_bases(%arg0 : index) -> (index, index) {
-  %c1 = arith.constant 1 : index
-  %0:2 = affine.delinearize_index %arg0 into (%c1, %c1) : index, index
+  %0:2 = affine.delinearize_index %arg0 into (1, 1) : index, index
   return %0#0, %0#1 : index, index
 }
 // CHECK-LABEL: func @drop_all_unit_bases(
@@ -1519,9 +1518,8 @@ func.func @drop_single_loop_delinearize(%arg0 : index, %arg1 : index) -> index {
 
 // CHECK-LABEL: func @delinearize_non_induction_variable
 func.func @delinearize_non_induction_variable(%arg0: memref<?xi32>, %i : index, %t0 : index, %t1 : index, %t2 : index) -> index {
-  %c1024 = arith.constant 1024 : index
   %1 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%i)[%t0, %t1, %t2]
-  %2 = affine.delinearize_index %1 into (%c1024) : index
+  %2 = affine.delinearize_index %1 into (1024) : index
   return %2 : index
 }
 
@@ -1529,7 +1527,6 @@ func.func @delinearize_non_induction_variable(%arg0: memref<?xi32>, %i : index,
 
 // CHECK-LABEL: func @delinearize_non_loop_like
 func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index {
-  %c1024 = arith.constant 1024 : index
-  %2 = affine.delinearize_index %i into (%c1024) : index
+  %2 = affine.delinearize_index %i into (1024) : index
   return %2 : index
 }
diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir
index f6e7b21bc66aba..3be14eaf5c3261 100644
--- a/mlir/test/Dialect/Affine/loop-coalescing.mlir
+++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir
@@ -6,9 +6,6 @@ func.func @one_3d_nest() {
   // upper bound is also the number of iterations.
   // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0
   // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1
-  // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3
-  // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42
-  // CHECK-DAG: %[[orig_ub_j:.*]] = arith.constant 56
   // CHECK-DAG: %[[range:.*]] = arith.constant 7056
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -25,7 +22,7 @@ func.func @one_3d_nest() {
 
     // Reconstruct original IVs from the linearized one.
     // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]]
-    // CHECK-SAME: into (%[[orig_ub_i]], %[[orig_ub_j]], %[[orig_ub_k]])
+    // CHECK-SAME: into (42, 56, 3)
     scf.for %j = %c0 to %c56 step %c1 {
       scf.for %k = %c0 to %c3 step %c1 {
         // CHECK: "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2)
@@ -73,11 +70,6 @@ func.func @unnormalized_loops() {
   // Normalized lower bound and step for the outer scf.
   // CHECK-DAG: %[[lb_i:.*]] = arith.constant 0
   // CHECK-DAG: %[[step_i:.*]] = arith.constant 1
-  // CHECK-DAG: %[[orig_step_j_and_numiter_i:.*]] = arith.constant 3
-
-  // Number of iterations in the inner loop, the pattern is the same as above,
-  // only capture the final result.
-  // CHECK-DAG: %[[numiter_j:.*]] = arith.constant 4
 
   // CHECK-DAG: %[[range:.*]] = arith.constant 12
 
@@ -97,7 +89,7 @@ func.func @unnormalized_loops() {
     scf.for %j = %c7 to %c17 step %c3 {
       // The IVs are rewritten.
       // CHECK: %[[delinearize:.+]]:2 = affine.delinearize_index %[[i]]
-      // CHECK-SAME: into (%[[orig_step_j_and_numiter_i]], %[[numiter_j]])
+      // CHECK-SAME: into (3, 4)
       // CHECK: %[[orig_j:.*]] = affine.apply affine_map<(d0) -> (d0 * 3 + 7)>(%[[delinearize]]#1)
       // CHECK: %[[orig_i:.*]] = affine.apply affine_map<(d0) -> (d0 * 2 + 5)>(%[[delinearize]]#0)
       // CHECK: "use"(%[[orig_i]], %[[orig_j]])
@@ -111,10 +103,7 @@ func.func @unnormalized_loops() {
 
 func.func @noramalized_loops_with_yielded_iter_args() {
   // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0
-  // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42
   // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1
-  // CHECK-DAG: %[[orig_ub_j:.*]] = arith.constant 56
-  // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3
   // CHECK-DAG: %[[range:.*]] = arith.constant 7056
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -130,7 +119,7 @@ func.func @noramalized_loops_with_yielded_iter_args() {
     // CHECK-NOT: scf.for
 
     // Reconstruct original IVs from the linearized one.
-    // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]] into (%[[orig_ub_i]], %[[orig_ub_j]], %[[orig_ub_k]])
+    // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]] into (42, 56, 3)
     %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
       %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
         // CHECK: "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2)
@@ -150,9 +139,6 @@ func.func @noramalized_loops_with_yielded_iter_args() {
 func.func @noramalized_loops_with_shuffled_yielded_iter_args() {
   // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0
   // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1
-  // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3
-  // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42
-  // CHECK-DAG: %[[orig_ub_j:.*]] = arith.constant 56
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c3 = arith.constant 3 : index
@@ -169,7 +155,7 @@ func.func @noramalized_loops_with_shuffled_yielded_iter_args() {
 
     // Reconstruct original IVs from the linearized one.
     // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]]
-    // CHECK-SAME: into (%[[orig_ub_i]], %[[orig_ub_j]], %[[orig_ub_k]])
+    // CHECK-SAME: into (42, 56, 3)
     %1:2 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (index, index){
       %0:2 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (index, index) {
         // CHECK: "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2)
@@ -189,9 +175,6 @@ func.func @noramalized_loops_with_shuffled_yielded_iter_args() {
 func.func @noramalized_loops_with_yielded_non_iter_args() {
   // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0
   // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1
-  // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3
-  // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42
-  // CHECK-DAG: %[[orig_ub_j:.*]] = arith.constant 56
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c3 = arith.constant 3 : index
@@ -208,7 +191,7 @@ func.func @noramalized_loops_with_yielded_non_iter_args() {
 
     // Reconstruct original IVs from the linearized one.
     // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]]
-    // CHECK-SAME: into (%[[orig_ub_i]], %[[orig_ub_j]], %[[orig_ub_k]])
+    // CHECK-SAME: into (42, 56, 3)
     %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
       %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
         // CHECK: %[[res:.*]] = "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2)
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 19ae1584842aea..52ae53adcea9f9 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -275,3 +275,10 @@ func.func @delinearize(%linear_idx: index, %basis0: index, %basis1 :index) -> (i
   %1:2 = affine.delinearize_index %linear_idx into (%basis0, %basis1) : index, index
   return %1#0, %1#1 : index, index
 }
+
+// CHECK-LABEL: @delinearize_mixed
+func.func @delinearize_mixed(%linear_idx: index, %basis1: index) -> (index, index, index) {
+  // CHECK: affine.delinearize_index %{{.+}} into (2, %{{.+}}, 3) : index, index, index
+  %1:3 = affine.delinearize_index %linear_idx into (2, %basis1, 3) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
index 3669cae87408df..4bb099e3401ecf 100644
--- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -11,12 +11,9 @@ func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf3
 // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
 // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
-// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
-// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
 // CHECK-DAG: %[[init:.+]] = tensor.empty() : tensor<20x11xf32>
 // CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
-//     CHECK:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
+//     CHECK:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (3, 5, 7
 //     CHECK:   %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
 //     CHECK:   %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
 //     CHECK:   %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] :
@@ -24,12 +21,9 @@ func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf3
 //     CHECK: return %[[tile]]
 
 //     FOREACH: func.func @extract_slice_static(%[[arg0:.+]]:
-// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
-// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index
-// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index
 // FOREACH-DAG: %[[init:.+]] = tensor.empty() : tensor<20x11xf32>
 //     FOREACH: %[[tile:.+]] = scf.forall (%[[iv:.+]]) in (20) shared_outs(%[[dest:.+]] = %[[init]])
-//     FOREACH:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
+//     FOREACH:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (3, 5, 7
 //     FOREACH:   %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
 //     FOREACH:   %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
 //     FOREACH:   in_parallel
@@ -50,13 +44,10 @@ func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor<
 // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
 // CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
-// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
-// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
-// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
 //     CHECK: %[[init:.+]] = tensor.empty() : tensor<10x5xf32>
 //     CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
 //     CHECK:   %[[inputIv:.+]] = affine.apply #[[$map0]](%[[iv]])
-//     CHECK:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[c5]], %[[c7]]
+//     CHECK:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (3, 5, 7
 //     CHECK:   %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
 //     CHECK:   %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
 //     CHECK:   %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
@@ -78,13 +69,12 @@ func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %si
 // CHECK-DAG:   %[[c0:.+]] = arith.constant 0 : index
 // CHECK-DAG:   %[[c1:.+]] = arith.constant 1 : index
 // CHECK-DAG:   %[[c2:.+]] = arith.constant 2 : index
-// CHECK-DAG:   %[[c3:.+]] = arith.constant 3 : index
 //     CHECK:   %[[init:.+]] = tensor.empty(%[[sz]]) : tensor<?x5xf32>
 // CHECK-DAG:   %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : tensor<3x?x?x11xf32>
 // CHECK-DAG:   %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32>
 //     CHECK:   %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
 //     CHECK:     %[[inputIv:.+]] = affine.apply #[[map0]](%[[iv]])[%[[lb]]]
-//     CHECK:     %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[d1]], %[[d2]]) :
+//     CHECK:     %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (3, %[[d1]], %[[d2]]) :
 //     CHECK:     %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
 //     CHECK:     %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
 //     CHECK:     %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
@@ -105,9 +95,7 @@ func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0
 // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
 // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
 // CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
 // CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
-// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index
 //     CHECK: %[[init:.+]] = tensor.empty(%[[sz1]], %[[sz2]]) : tensor<?x?xf32>
 // CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
 // CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
@@ -115,9 +103,9 @@ func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0
 //     CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[c0]] to %[[sz1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]])
 //     CHECK:   %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[c0]] to %[[sz2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]])
 //     CHECK:       %[[inputIv1:.+]] = affine.apply #[[map0:.+]](%[[iv1]])[%[[lb1]]]
-//     CHECK:       %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
+//     CHECK:       %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (3, %[[d1]], %[[d2]]) :
 //     CHECK:       %[[inputIv2:.+]] = affine.apply #[[map0:.+]](%[[iv2]])[%[[lb2]]]
-//     CHECK:       %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (%[[c11]], %[[d4]]) :
+//     CHECK:       %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (11, %[[d4]]) :
 //     CHECK:       %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
 //     CHECK:       %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
 //     CHECK:       %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] :
@@ -129,18 +117,16 @@ func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0
 //     FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
 // FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index
 // FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index
-// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
 // FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index
-// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index
 //     FOREACH:     %[[init:.+]] = tensor.empty(%[[sz1]], %[[sz2]]) : tensor<?x?xf32>
 // FOREACH-DAG:     %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
 // FOREACH-DAG:     %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
 // FOREACH-DAG:     %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
 //     FOREACH:     %[[tile1:.+]] = scf.forall (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], %[[sz2]]) shared_outs(%[[dest:.+]] = %[[init]])
 // FOREACH-DAG:       %[[iv1:.+]] = affine.apply #[[map1]](%[[tid1]])[%[[lb1]]]
-//     FOREACH:       %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
+//     FOREACH:       %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (3, %[[d1]], %[[d2]]) :
 // FOREACH-DAG:       %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]]
-//     FOREACH:       %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (%[[c11]], %[[d4]]) :
+//     FOREACH:       %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (11, %[[d4]]) :
 //     FOREACH:       %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
 //     FOREACH:       %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
 //     FOREACH:       in_parallel
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 0544cef3e38281..3acddd6e54639e 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1615,6 +1615,6 @@ func.func @warp_propagate_nd_write(%laneid: index, %dest: memref<4x1024xf32>) {
 //  CHECK-DIST-AND-PROP-SAME:       vector<4x1024xf32>
 //       CHECK-DIST-AND-PROP:   }
 
-//       CHECK-DIST-AND-PROP:   %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (%c4, %c8) : index, index
+//       CHECK-DIST-AND-PROP:   %[[IDS:.+]]:2 = affine.delinearize_index %{{.*}} into (4, 8) : index, index
 //       CHECK-DIST-AND-PROP:   %[[INNER_ID:.+]] = affine.apply #map()[%[[IDS]]#1]
 //       CHECK-DIST-AND-PROP:   vector.transfer_write %[[W]], %{{.*}}[%[[IDS]]#0, %[[INNER_ID]]] {{.*}} : vector<1x128xf32>
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 6f39e1348fcd57..58be05a8eb7917 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -47,11 +47,10 @@ def affine_store_test(arg0):
 # CHECK-LABEL: TEST: testAffineDelinearizeInfer
 @constructAndPrintInModule
 def testAffineDelinearizeInfer():
-    # CHECK: %[[C0:.*]] = arith.constant 0 : index
     c0 = arith.ConstantOp(T.index(), 0)
     # CHECK: %[[C1:.*]] = arith.constant 1 : index
     c1 = arith.ConstantOp(T.index(), 1)
-    # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1:.*]] into (%[[C1:.*]], %[[C0:.*]]) : index, index
+    # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1:.*]] into (1, 0) : index, index
     two_indices = affine.AffineDelinearizeIndexOp(c1, [c1, c0])
 
 

>From ad15902093d7d1d7f54b747bd64555c11f491efa Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 28 Oct 2024 22:37:31 +0000
Subject: [PATCH 2/4] Adjust Python test

---
 mlir/test/python/dialects/affine.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 58be05a8eb7917..73864708b2b221 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -47,11 +47,10 @@ def affine_store_test(arg0):
 # CHECK-LABEL: TEST: testAffineDelinearizeInfer
 @constructAndPrintInModule
 def testAffineDelinearizeInfer():
-    c0 = arith.ConstantOp(T.index(), 0)
     # CHECK: %[[C1:.*]] = arith.constant 1 : index
     c1 = arith.ConstantOp(T.index(), 1)
-    # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1:.*]] into (1, 0) : index, index
-    two_indices = affine.AffineDelinearizeIndexOp(c1, [c1, c0])
+    # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1:.*]] into (2, 3) : index, index
+    two_indices = affine.AffineDelinearizeIndexOp(c1, [], [2, 3])
 
 
 # CHECK-LABEL: TEST: testAffineLoadOp

>From f3696112de92a381b546ad6bd4f490fe084c991f Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 30 Oct 2024 15:41:10 +0000
Subject: [PATCH 3/4] Address review feedback, update Python test

---
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       | 62 ++++++++-----------
 .../AffineToStandard/lower-affine.mlir        |  4 +-
 mlir/test/python/dialects/affine.py           |  2 +-
 3 files changed, 28 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index e3b5d26e0ec3c3..2680502bb687d3 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1931,49 +1931,36 @@ DivModValue mlir::affine::getDivMod(OpBuilder &b, Location loc, Value lhs,
   return result;
 }
 
-/// Create IR that computes the product of all elements in the set.
-static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
-                                               ArrayRef<Value> set) {
-  if (set.empty())
-    return failure();
-  OpFoldResult result = set[0];
-  AffineExpr s0, s1;
-  bindSymbols(b.getContext(), s0, s1);
-  for (unsigned i = 1, e = set.size(); i < e; i++)
-    result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]});
-  return result;
-}
-
-static FailureOr<OpFoldResult> getIndexProduct(OpBuilder &b, Location loc,
-                                               ArrayRef<OpFoldResult> set) {
-  if (set.empty())
-    return failure();
-  OpFoldResult result = set[0];
+/// Create an affine map that computes `lhs` * `rhs`, composing in any other
+/// affine maps.
+static FailureOr<OpFoldResult> composedAffineMultiply(OpBuilder &b,
+                                                      Location loc,
+                                                      OpFoldResult lhs,
+                                                      OpFoldResult rhs) {
   AffineExpr s0, s1;
   bindSymbols(b.getContext(), s0, s1);
-  for (unsigned i = 1, e = set.size(); i < e; i++)
-    result = makeComposedFoldedAffineApply(b, loc, s0 * s1, {result, set[i]});
-  return result;
+  return makeComposedFoldedAffineApply(b, loc, s0 * s1, {lhs, rhs});
 }
 
 FailureOr<SmallVector<Value>>
 mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
                                ArrayRef<Value> basis) {
-  unsigned numDims = basis.size();
-
+  // Note: the scan builds the divisors innermost-first, i.e. in reverse order.
   SmallVector<Value> divisors;
-  for (unsigned i = 1; i < numDims; i++) {
-    ArrayRef<Value> slice = basis.drop_front(i);
-    FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice);
-    if (failed(prod))
+  OpFoldResult basisProd = b.getIndexAttr(1);
+  for (OpFoldResult basisElem : llvm::reverse(basis.drop_front())) {
+    FailureOr<OpFoldResult> nextProd =
+        composedAffineMultiply(b, loc, basisElem, basisProd);
+    if (failed(nextProd))
       return failure();
-    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
+    basisProd = *nextProd;
+    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, basisProd));
   }
 
   SmallVector<Value> results;
   results.reserve(divisors.size() + 1);
   Value residual = linearIndex;
-  for (Value divisor : divisors) {
+  for (Value divisor : llvm::reverse(divisors)) {
     DivModValue divMod = getDivMod(b, loc, residual, divisor);
     results.push_back(divMod.quotient);
     residual = divMod.remainder;
@@ -1985,21 +1972,22 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
 FailureOr<SmallVector<Value>>
 mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
                                ArrayRef<OpFoldResult> basis) {
-  unsigned numDims = basis.size();
-
+  // Note: the scan builds the divisors innermost-first, i.e. in reverse order.
   SmallVector<Value> divisors;
-  for (unsigned i = 1; i < numDims; i++) {
-    ArrayRef<OpFoldResult> slice = basis.drop_front(i);
-    FailureOr<OpFoldResult> prod = getIndexProduct(b, loc, slice);
-    if (failed(prod))
+  OpFoldResult basisProd = b.getIndexAttr(1);
+  for (OpFoldResult basisElem : llvm::reverse(basis.drop_front())) {
+    FailureOr<OpFoldResult> nextProd =
+        composedAffineMultiply(b, loc, basisElem, basisProd);
+    if (failed(nextProd))
       return failure();
-    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, *prod));
+    basisProd = *nextProd;
+    divisors.push_back(getValueOrCreateConstantIndexOp(b, loc, basisProd));
   }
 
   SmallVector<Value> results;
   results.reserve(divisors.size() + 1);
   Value residual = linearIndex;
-  for (Value divisor : divisors) {
+  for (Value divisor : llvm::reverse(divisors)) {
     DivModValue divMod = getDivMod(b, loc, residual, divisor);
     results.push_back(divMod.quotient);
     residual = divMod.remainder;
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 298e82df4f4cea..3781d510897f8f 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -936,8 +936,8 @@ func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index)
 }
 // CHECK-LABEL:   func.func @test_dilinearize_index(
 // CHECK-SAME:                                      %[[VAL_0:.*]]: index) -> (index, index, index) {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 50176 : index
-// CHECK:           %[[VAL_2:.*]] = arith.constant 224 : index
+// CHECK:           %[[VAL_1:.*]] = arith.constant 224 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 50176 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 50176 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_5:.*]] = arith.constant -1 : index
diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 73864708b2b221..0dc69d7ba522de 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -157,7 +157,7 @@ def testAffineForOpErrors():
         )
 
     try:
-        two_indices = affine.AffineDelinearizeIndexOp(c1, [c1, c1])
+        two_indices = affine.AffineDelinearizeIndexOp(c1, [], [1, 1])
         affine.AffineForOp(
             two_indices,
             c2,

>From 06bad6d474535a3f81dfca6b8bc5bc4224125324 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 4 Nov 2024 12:34:03 -0600
Subject: [PATCH 4/4] Formatting fixes

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/include/mlir/Dialect/Affine/IR/AffineOps.td | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index f53b5d97a7156a..e9480d30c2d701 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1102,12 +1102,11 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   ];
 
   let extraClassDeclaration = [{
-    /// Return a vector with all the static and dynamic basis values.
+    /// Returns a vector with all the static and dynamic basis values.
     SmallVector<OpFoldResult> getMixedBasis() {
       OpBuilder builder(getContext());
       return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
     }
-
   }];
 
   let hasVerifier = 1;



More information about the Mlir-commits mailing list