[Mlir-commits] [mlir] [mlir][affine] Define `affine.linearize_index` (PR #114480)

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Nov 4 10:56:55 PST 2024


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/114480

>From e640096b8d107d54610fe6507edbc9cc02cbc61f Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 30 Oct 2024 20:53:43 +0000
Subject: [PATCH 1/4] [mlir][affine] Define `affine.linearize_index`

`affine.linearize_index` is the inverse of `affine.delinearize_index`
and is generally useful for representing computations (like those needed
to move from N-D to 1-D memrefs) that combine several indices into one.

This commit introduces `affine.linearize_index` and one simple
canonicalization for it.

There are plans to add `affine.linearize_index` and
`affine.delinearize_index` pair canonicalizations, but we are saving
those for a followup PR (especially since having #113846 landed would
make them nicer).

Note that while `affine` may not be the natural home for this operation,
https://discourse.llvm.org/t/better-location-of-affine-delinearize-operation/80565/13
did not reach consensus on a better location.
---
 .../mlir/Dialect/Affine/IR/AffineOps.h        |   1 +
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  69 ++++++++++++
 mlir/include/mlir/Dialect/Affine/Utils.h      |   3 +
 .../mlir/Interfaces/ViewLikeInterface.h       |  16 +++
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 106 ++++++++++++++++++
 .../Transforms/AffineExpandIndexOps.cpp       |  22 +++-
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       |  14 ++-
 .../AffineToStandard/lower-affine.mlir        |  17 +++
 .../Affine/affine-expand-index-ops.mlir       |  26 +++++
 mlir/test/Dialect/Affine/canonicalize.mlir    |  34 ++++++
 mlir/test/Dialect/Affine/invalid.mlir         |  16 +++
 mlir/test/Dialect/Affine/ops.mlir             |  16 +++
 12 files changed, 335 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 5c75e102c3d404..40ce26d6527cc0 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 8773fc5881461a..c49aa8281f49d1 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1099,4 +1099,73 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// AffineLinearizeIndexOp
+//===----------------------------------------------------------------------===//
+def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
+    [Pure, AttrSizedOperandSegments]> {
+  let summary = "linearize an index";
+  let description = [{
+    The `affine.linearize_index` operation takes a sequence of index values and a
+    basis of the same length and linearizes the indices using that basis.
+
+    That is, for indices %idx_1 through %i_N and basis elements b_1 through b_N,
+    it computes
+
+    ```
+    sum(i = 1 to N) %idx_i * product(j = i + 1 to N) B_j
+    ```
+
+    If the `disjoint` property is present, this is an optimization hint that,
+    for all i, 0 <= %idx_i < B_i - that is, no index affects any other index,
+    except that %idx_0 may be negative to make the index as a whole negative.
+
+    Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.
+
+    Example:
+
+    ```
+    %linear_index = affine.delinearize_index [%index_0, %index_1, %index_2] (16, 224, 224) : index
+    ```
+
+    In the above example, `%linear_index` conceptually holds the following:
+
+    ```
+    #map = affine_map<()[s0, s1, s2] -> (s0 * 50176 + s1 * 224 + s2)>
+    %linear_index = affine.apply #map()[%index_0, %index_1, %index_2]
+    ```
+  }];
+
+  // The basis is stored split: `static_basis` has one entry per basis
+  // element, and each entry equal to `ShapedType::kDynamic` is supplied by an
+  // operand of `dynamic_basis` (the verifier checks that the counts agree).
+  // Use `getMixedBasis()` to see the combined static/dynamic basis.
+  let arguments = (ins Variadic<Index>:$multi_index,
+    Variadic<Index>:$dynamic_basis,
+    DenseI64ArrayAttr:$static_basis,
+    UnitProperty:$disjoint);
+  let results = (outs Index:$linear_index);
+
+  // Textual form, e.g.:
+  //   affine.linearize_index disjoint [%i, %j, %k] by (2, %b, 3) : index
+  let assemblyFormat = [{
+    (`disjoint` $disjoint^)? ` `
+    `[` $multi_index `]` `by` ` `
+    custom<DynamicIndexList>($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren")
+    attr-dict `:` type($linear_index)
+  }];
+
+  // Convenience builders accepting the basis as SSA values, as a mixed
+  // static/dynamic `OpFoldResult` list, or as constants.
+  let builders = [
+    OpBuilder<(ins "ValueRange":$multi_index, "ValueRange":$basis, CArg<"bool", "false">:$disjoint)>,
+    OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<OpFoldResult>":$basis, CArg<"bool", "false">:$disjoint)>,
+    OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef<int64_t>":$basis, CArg<"bool", "false">:$disjoint)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return a vector with all the static and dynamic basis values.
+    SmallVector<OpFoldResult> getMixedBasis() {
+      OpBuilder builder(getContext());
+      return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+    }
+
+  }];
+
+  // The verifier and canonicalization patterns are implemented in C++
+  // (AffineLinearizeIndexOp::verify / getCanonicalizationPatterns).
+  let hasVerifier = 1;
+  let hasCanonicalizer = 1;
+}
+
 #endif // AFFINE_OPS
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 9a2767e0ad87f3..f518272ec0e8f9 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -316,6 +316,9 @@ FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
 OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                             ArrayRef<OpFoldResult> basis,
                             ImplicitLocOpBuilder &builder);
+/// Variant of the `ImplicitLocOpBuilder` overload of `linearizeIndex` that
+/// takes the location for any created ops explicitly. `multiIndex` and
+/// `basis` must have the same number of elements.
+OpFoldResult linearizeIndex(OpBuilder &builder, Location loc,
+                            ArrayRef<OpFoldResult> multiIndex,
+                            ArrayRef<OpFoldResult> basis);
 
 /// Ensure that all operations that could be executed after `start`
 /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index d6479143a0a50b..3dcbd2f1af1936 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -109,6 +109,13 @@ void printDynamicIndexList(
     ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
     TypeRange valueTypes = TypeRange(),
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+/// Prints a dynamic index list wrapped in `delimiter`, with no
+/// scalable-dimension flags and no value types. Intended for
+/// `custom<DynamicIndexList>` directives that only pass a delimiter.
+inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+                                  OperandRange values,
+                                  ArrayRef<int64_t> integers,
+                                  AsmParser::Delimiter delimiter) {
+  ArrayRef<bool> noScalables;
+  printDynamicIndexList(printer, op, values, integers, noScalables,
+                        /*valueTypes=*/TypeRange(), delimiter);
+}
 inline void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
@@ -144,6 +151,15 @@ ParseResult parseDynamicIndexList(
     DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
     SmallVectorImpl<Type> *valueTypes = nullptr,
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+/// Parses a dynamic index list wrapped in `delimiter`. Scalable-dimension
+/// flags are parsed into a local and discarded, and no value types are
+/// collected.
+inline ParseResult
+parseDynamicIndexList(OpAsmParser &parser,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      DenseI64ArrayAttr &integers,
+                      AsmParser::Delimiter delimiter) {
+  // Required by the general parser, but callers of this overload ignore it.
+  DenseBoolArrayAttr ignoredScalables;
+  return parseDynamicIndexList(parser, values, integers, ignoredScalables,
+                               /*valueTypes=*/nullptr, delimiter);
+}
 inline ParseResult parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 5e7a6b6ca883c3..2d3b5a80df4b0e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4664,6 +4664,112 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
   patterns.insert<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// LinearizeIndexOp
+//===----------------------------------------------------------------------===//
+
+/// Builds the op from a basis given entirely as SSA values. The values are
+/// wrapped with `getAsOpFoldResult` (which may turn constants into
+/// attributes) and forwarded to the `ArrayRef<OpFoldResult>` overload.
+void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                   OperationState &odsState,
+                                   ValueRange multiIndex, ValueRange basis,
+                                   bool disjoint) {
+  SmallVector<OpFoldResult> mixedBasis = getAsOpFoldResult(basis);
+  build(odsBuilder, odsState, multiIndex, ArrayRef<OpFoldResult>(mixedBasis),
+        disjoint);
+}
+
+/// Builds the op from a mixed basis: attribute entries are recorded in
+/// `static_basis`, while SSA values become `dynamic_basis` operands whose
+/// positions are marked with `ShapedType::kDynamic` in `static_basis`.
+void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                   OperationState &odsState,
+                                   ValueRange multiIndex,
+                                   ArrayRef<OpFoldResult> basis,
+                                   bool disjoint) {
+  SmallVector<int64_t> constantBasis;
+  SmallVector<Value> valueBasis;
+  dispatchIndexOpFoldResults(basis, valueBasis, constantBasis);
+  build(odsBuilder, odsState, multiIndex, valueBasis, constantBasis, disjoint);
+}
+
+/// Builds the op from a fully static basis; no dynamic basis operands are
+/// created.
+void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                   OperationState &odsState,
+                                   ValueRange multiIndex,
+                                   ArrayRef<int64_t> basis, bool disjoint) {
+  build(odsBuilder, odsState, multiIndex, /*dynamicBasis=*/ValueRange{}, basis,
+        disjoint);
+}
+
+/// Verifies the structural invariants of `affine.linearize_index`:
+///  - the basis is non-empty;
+///  - there is exactly one index per basis element;
+///  - the number of `ShapedType::kDynamic` markers in `static_basis` equals
+///    the number of `dynamic_basis` operands.
+/// NOTE(review): the last check fires in both directions (missing or extra
+/// dynamic operands), although its message only describes the
+/// "marker but no operand" case.
+LogicalResult AffineLinearizeIndexOp::verify() {
+  if (getStaticBasis().empty())
+    return emitOpError("basis should not be empty");
+  if (getMultiIndex().size() != getStaticBasis().size())
+    return emitOpError("should be passed an index for each basis element");
+  auto dynamicMarkersCount =
+      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
+    return emitOpError(
+        "mismatch between dynamic and static basis (kDynamic marker but no "
+        "corresponding dynamic basis entry) -- this can only happen due to an "
+        "incorrect fold/rewrite");
+  return success();
+}
+
+namespace {
+/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
+/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
+/// %...d)`.
+///
+/// Note that `disjoint` is required here because, without it,
+/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
+/// is a valid operation in which `%c64` cannot be trivially dropped.
+///
+/// The outermost index (the first entry of the multi-index) is never dropped
+/// on the strength of `disjoint` alone. Its basis element does not appear in
+/// any stride, and the `disjoint` hint still allows that index to be
+/// negative, so it contributes to the result even when its basis element is 1.
+///
+/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
+/// the operation isn't asserted to be `disjoint`.
+struct DropLinearizeUnitComponentsIfDisjointOrZero
+    : public OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    size_t numIndices = op.getMultiIndex().size();
+    SmallVector<Value> newIndices;
+    newIndices.reserve(numIndices);
+    SmallVector<OpFoldResult> newBasis;
+    newBasis.reserve(numIndices);
+
+    SmallVector<OpFoldResult> basis = op.getMixedBasis();
+    size_t position = 0;
+    for (auto [index, basisElem] : llvm::zip_equal(op.getMultiIndex(), basis)) {
+      bool isOutermost = position++ == 0;
+
+      std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
+      std::optional<int64_t> indexValue = getConstantIntValue(index);
+      bool hasUnitBasis = basisEntry && *basisEntry == 1;
+      bool isKnownZero = indexValue && *indexValue == 0;
+      // A unit-basis component can be dropped when it is a known zero, or when
+      // `disjoint` bounds it to [0, 1). The latter does not apply to the
+      // outermost index, which `disjoint` allows to be negative.
+      bool canDrop =
+          hasUnitBasis && (isKnownZero || (op.getDisjoint() && !isOutermost));
+      if (canDrop)
+        continue;
+
+      newIndices.push_back(index);
+      newBasis.push_back(basisElem);
+    }
+    // Nothing was dropped; report failure so the rewrite driver moves on.
+    if (newIndices.size() == numIndices)
+      return failure();
+
+    // Every component contributed zero, so the linearized index is 0.
+    if (newIndices.empty()) {
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+      return success();
+    }
+    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+        op, newIndices, newBasis, op.getDisjoint());
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the canonicalization patterns for `affine.linearize_index`.
+void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<DropLinearizeUnitComponentsIfDisjointOrZero>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index c6bc3862256a75..c96c188cbc89fc 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -45,6 +46,24 @@ struct LowerDelinearizeIndexOps
   }
 };
 
+/// Lowers `affine.linearize_index` into a sequence of multiplications and
+/// additions.
+struct LowerLinearizeIndexOps
+    : public OpRewritePattern<AffineLinearizeIndexOp> {
+  using OpRewritePattern<AffineLinearizeIndexOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> multiIndex =
+        getAsOpFoldResult(op.getMultiIndex());
+    // `linearizeIndex` emits a composed and folded `affine.apply`. The
+    // multiply/add form appears once that `affine.apply` is itself lowered.
+    // The `disjoint` hint is not consulted here.
+    OpFoldResult linearIndex =
+        linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
+    // The result may have folded to an attribute; materialize a constant
+    // op in that case so there is a Value to replace the op with.
+    Value linearIndexValue =
+        getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
+    rewriter.replaceOp(op, linearIndexValue);
+    return success();
+  }
+};
+
 class ExpandAffineIndexOpsPass
     : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
 public:
@@ -64,7 +83,8 @@ class ExpandAffineIndexOpsPass
 
 void mlir::affine::populateAffineExpandIndexOpsPatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<LowerDelinearizeIndexOps>(patterns.getContext());
+  // Expand both `affine.delinearize_index` and `affine.linearize_index`.
+  patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
+      patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 910ad1733d03e8..76b67555feb132 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1973,6 +1973,12 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
 OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                                           ArrayRef<OpFoldResult> basis,
                                           ImplicitLocOpBuilder &builder) {
+  // Delegate to the explicit-location overload, using the builder's location.
+  return linearizeIndex(builder, builder.getLoc(), multiIndex, basis);
+}
+
+OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
+                                          ArrayRef<OpFoldResult> multiIndex,
+                                          ArrayRef<OpFoldResult> basis) {
   assert(multiIndex.size() == basis.size());
   SmallVector<AffineExpr> basisAffine;
   for (size_t i = 0; i < basis.size(); ++i) {
@@ -1983,13 +1989,13 @@ OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
   SmallVector<OpFoldResult> strides;
   strides.reserve(stridesAffine.size());
   llvm::transform(stridesAffine, std::back_inserter(strides),
-                  [&builder, &basis](AffineExpr strideExpr) {
+                  [&builder, &basis, loc](AffineExpr strideExpr) {
                     return affine::makeComposedFoldedAffineApply(
-                        builder, builder.getLoc(), strideExpr, basis);
+                        builder, loc, strideExpr, basis);
                   });
 
   auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex(
       OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex);
-  return affine::makeComposedFoldedAffineApply(
-      builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides);
+  return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr,
+                                               multiIndexAndStrides);
 }
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 23e0edd510cbb1..7e1c25355c3086 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -981,3 +981,20 @@ func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index)
 // CHECK:           %[[VAL_40:.*]] = arith.select %[[VAL_38]], %[[VAL_39]], %[[VAL_36]] : index
 // CHECK:           return %[[VAL_13]], %[[VAL_34]], %[[VAL_40]] : index, index, index
 // CHECK:         }
+
+/////////////////////////////////////////////////////////////////////
+
+// A `disjoint` linearize_index over the static basis (2, 3, 5) should lower to
+// arith multiplies by the strides 15 and 5, followed by additions.
+func.func @test_linearize_index(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 3, 5) : index
+  return %ret : index
+}
+
+// CHECK-LABEL: @test_linearize_index
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
+// CHECK: %[[c15:.+]] = arith.constant 15 : index
+// CHECK-NEXT: %[[tmp0:.+]] = arith.muli %[[arg0]], %[[c15]] : index
+// CHECK-NEXT: %[[c5:.+]] = arith.constant 5 : index
+// CHECK-NEXT: %[[tmp1:.+]] = arith.muli %[[arg1]], %[[c5]] : index
+// CHECK-NEXT: %[[tmp2:.+]] = arith.addi %[[tmp0]], %[[tmp1]] : index
+// CHECK-NEXT: %[[ret:.+]] = arith.addi %[[tmp2]], %[[arg2]] : index
+// CHECK-NEXT: return %[[ret]]
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
index 70b7f397ad4fec..fb20ba2ee69e01 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
@@ -44,3 +44,29 @@ func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (inde
   %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
   return %1#0, %1#1, %1#2 : index, index, index
 }
+
+// -----
+
+// Static basis: the strides 15 and 5 are folded into a single affine map.
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
+
+// CHECK-LABEL: @linearize_static
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
+// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]]
+// CHECK: return %[[val_0]]
+func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index
+  func.return %0 : index
+}
+
+// -----
+
+// Dynamic basis: the outermost basis element (%arg3) is not an operand of the
+// resulting affine.apply, because it does not contribute to any stride.
+// CHECK-DAG: #[[$map0:.+]] =  affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>
+
+// CHECK-LABEL: @linearize_dynamic
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index, %[[arg5:.+]]: index)
+// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg5]], %[[arg2]], %[[arg4]]]
+// CHECK: return %[[val_0]]
+func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> index {
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4, %arg5) : index
+  func.return %0 : index
+}
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 906ae81c76d115..5424e082902609 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1533,3 +1533,37 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
   %2 = affine.delinearize_index %i into (%c1024) : index
   return %2 : index
 }
+
+// -----
+
+// With `disjoint`, a non-outermost index whose basis entry is 1 is dropped.
+// CHECK-LABEL: @linearize_unit_basis_disjoint
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
+// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index
+// CHECK: return %[[ret]]
+func.func @linearize_unit_basis_disjoint(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> index {
+  %ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (3, 1, %arg3) : index
+  return %ret : index
+}
+
+// -----
+
+// Without `disjoint`, a unit-basis index is only dropped when it is a
+// constant 0.
+// CHECK-LABEL: @linearize_unit_basis_zero
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
+// CHECK: %[[ret:.+]] = affine.linearize_index [%[[arg0]], %[[arg1]]] by (3, %[[arg2]]) : index
+// CHECK: return %[[ret]]
+func.func @linearize_unit_basis_zero(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %c0 = arith.constant 0 : index
+  %ret = affine.linearize_index [%arg0, %c0, %arg1] by (3, 1, %arg2) : index
+  return %ret : index
+}
+
+// -----
+
+// When every component is dropped, the op is replaced by the constant 0.
+// CHECK-LABEL: @linearize_all_zero_unit_basis
+// CHECK: arith.constant 0 : index
+// CHECK-NOT: affine.linearize_index
+func.func @linearize_all_zero_unit_basis() -> index {
+  %c0 = arith.constant 0 : index
+  %ret = affine.linearize_index [%c0, %c0] by (1, 1) : index
+  return %ret : index
+}
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 869ea712bb3690..2996194170900f 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -548,6 +548,22 @@ func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
 
 // -----
 
+// The number of indices must match the number of basis elements.
+func.func @linearize(%idx: index, %basis0: index, %basis1 :index) -> index {
+  // expected-error at +1 {{'affine.linearize_index' op should be passed an index for each basis element}}
+  %0 = affine.linearize_index [%idx] by (%basis0, %basis1) : index
+  return %0 : index
+}
+
+// -----
+
+// An empty basis is rejected by the verifier.
+func.func @linearize_empty() -> index {
+  // expected-error at +1 {{'affine.linearize_index' op basis should not be empty}}
+  %0 = affine.linearize_index [] by () : index
+  return %0 : index
+}
+
+// -----
+
 func.func @dynamic_dimension_index() {
   "unknown.region"() ({
     %idx = "unknown.test"() : () -> (index)
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 19ae1584842aea..891415a8bd92b3 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -275,3 +275,19 @@ func.func @delinearize(%linear_idx: index, %basis0: index, %basis1 :index) -> (i
   %1:2 = affine.delinearize_index %linear_idx into (%basis0, %basis1) : index, index
   return %1#0, %1#1 : index, index
 }
+
+// -----
+
+// Custom assembly format with a fully dynamic basis.
+// CHECK-LABEL: func @linearize
+func.func @linearize(%index0: index, %index1: index, %basis0: index, %basis1 :index) -> index {
+  // CHECK: affine.linearize_index [%{{.+}}, %{{.+}}] by (%{{.+}}, %{{.+}}) : index
+  %1 = affine.linearize_index [%index0, %index1] by (%basis0, %basis1) : index
+  return %1 : index
+}
+
+// Custom assembly format with a mixed static/dynamic basis and `disjoint`.
+// CHECK-LABEL: @linearize_mixed
+func.func @linearize_mixed(%index0: index, %index1: index, %index2: index, %basis1: index) -> index {
+  // CHECK: affine.linearize_index disjoint [%{{.+}}, %{{.+}}, %{{.+}}] by (2, %{{.+}}, 3) : index
+  %1 = affine.linearize_index disjoint [%index0, %index1, %index2] by (2, %basis1, 3) : index
+  return %1 : index
+}

>From 0b396e54fc28993edc698eef6f9e438a46932328 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 4 Nov 2024 12:26:36 -0600
Subject: [PATCH 2/4] Fixups and minor changes from code review

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/include/mlir/Dialect/Affine/IR/AffineOps.td            | 6 +++---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp                    | 4 ++--
 mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp | 6 +++---
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index c49aa8281f49d1..df1944acd9ae39 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1109,7 +1109,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     The `affine.linearize_index` operation takes a sequence of index values and a
     basis of the same length and linearizes the indices using that basis.
 
-    That is, for indices %idx_1 through %i_N and basis elements b_1 through b_N,
+    That is, for indices `%idx_1` through `%i_N` and basis elements `b_1` through `b_N`,
     it computes
 
     ```
@@ -1124,13 +1124,13 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
 
     Example:
 
-    ```
+    ```mlir
     %linear_index = affine.delinearize_index [%index_0, %index_1, %index_2] (16, 224, 224) : index
     ```
 
     In the above example, `%linear_index` conceptually holds the following:
 
-    ```
+    ```mlir
     #map = affine_map<()[s0, s1, s2] -> (s0 * 50176 + s1 * 224 + s2)>
     %linear_index = affine.apply #map()[%index_0, %index_1, %index_2]
     ```
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2d3b5a80df4b0e..fb56eefb0dd6d3 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4723,8 +4723,8 @@ namespace {
 ///
 /// Alternatively, if `%x` in the above is a known constant 0, remove it even if
 /// the operation isn't asserted to be `disjoint`.
-struct DropLinearizeUnitComponentsIfDisjointOrZero
-    : public OpRewritePattern<affine::AffineLinearizeIndexOp> {
+struct DropLinearizeUnitComponentsIfDisjointOrZero final
+    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index c96c188cbc89fc..67cf67865e18cb 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -48,9 +48,9 @@ struct LowerDelinearizeIndexOps
 
 /// Lowers `affine.linearize_index` into a sequence of multiplications and
 /// additions.
-struct LowerLinearizeIndexOps
-    : public OpRewritePattern<AffineLinearizeIndexOp> {
-  using OpRewritePattern<AffineLinearizeIndexOp>::OpRewritePattern;
+struct LowerLinearizeIndexOps final
+    : OpRewritePattern<AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<OpFoldResult> multiIndex =

>From b6f5cba4bfc4ba71f08d44678033ff6251731127 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 4 Nov 2024 18:40:26 +0000
Subject: [PATCH 3/4] Review comments that weren't suggestions

---
 mlir/include/mlir/Dialect/Affine/IR/AffineOps.td | 10 +++++-----
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp         |  3 +++
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index df1944acd9ae39..0c214120b71892 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1109,7 +1109,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     The `affine.linearize_index` operation takes a sequence of index values and a
     basis of the same length and linearizes the indices using that basis.
 
-    That is, for indices `%idx_1` through `%i_N` and basis elements `b_1` through `b_N`,
+    That is, for indices `%idx_1` through `%idx_N` and basis elements `B_1` through `B_N`,
     it computes
 
     ```
@@ -1117,21 +1117,21 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     ```
 
     If the `disjoint` property is present, this is an optimization hint that,
-    for all i, 0 <= %idx_i < B_i - that is, no index affects any other index,
-    except that %idx_0 may be negative to make the index as a whole negative.
+    for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
+    except that `%idx_1` may be negative to make the index as a whole negative.
 
     Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`.
 
     Example:
 
     ```mlir
-    %linear_index = affine.delinearize_index [%index_0, %index_1, %index_2] (16, 224, 224) : index
+    %linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (2, 3, 5) : index
     ```
 
     In the above example, `%linear_index` conceptually holds the following:
 
     ```mlir
-    #map = affine_map<()[s0, s1, s2] -> (s0 * 50176 + s1 * 224 + s2)>
+    #map = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)>
     %linear_index = affine.apply #map()[%index_0, %index_1, %index_2]
     ```
   }];
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index fb56eefb0dd6d3..3c0d78c20d3b76 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4700,8 +4700,10 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
 LogicalResult AffineLinearizeIndexOp::verify() {
   if (getStaticBasis().empty())
     return emitOpError("basis should not be empty");
+
   if (getMultiIndex().size() != getStaticBasis().size())
     return emitOpError("should be passed an index for each basis element");
+
   auto dynamicMarkersCount =
       llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
   if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
@@ -4709,6 +4711,7 @@ LogicalResult AffineLinearizeIndexOp::verify() {
         "mismatch between dynamic and static basis (kDynamic marker but no "
         "corresponding dynamic basis entry) -- this can only happen due to an "
         "incorrect fold/rewrite");
+
   return success();
 }
 

>From 7f621e3ad6433aa69ec9c13fe588476fb7307f99 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 4 Nov 2024 18:56:28 +0000
Subject: [PATCH 4/4] Clang-format

---
 mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 67cf67865e18cb..5903c363dc55db 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -48,8 +48,7 @@ struct LowerDelinearizeIndexOps
 
 /// Lowers `affine.linearize_index` into a sequence of multiplications and
 /// additions.
-struct LowerLinearizeIndexOps final
-    : OpRewritePattern<AffineLinearizeIndexOp> {
+struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {



More information about the Mlir-commits mailing list