Add `CombinationsWithRepeats` method to `EnumerableExtensions` with corresponding tests

This commit is contained in:
Sebastian Lindemeier 2025-12-15 13:45:21 +01:00
parent 149977d02f
commit 0665713725
2 changed files with 67 additions and 5 deletions

View File

@ -11,7 +11,7 @@ public class EnumerableExtensionsTest
var actual = data.Combinations(2).ToArray(); var actual = data.Combinations(2).ToArray();
Assert.Contains([1, 2], actual); Assert.Contains([1, 2], actual);
Assert.Equal(1, actual.Length); Assert.Single(actual);
} }
[Fact] [Fact]
@ -35,7 +35,37 @@ public class EnumerableExtensionsTest
var actual = data.Combinations(3).ToArray(); var actual = data.Combinations(3).ToArray();
Assert.Contains([1, 2, 3], actual); Assert.Contains([1, 2, 3], actual);
Assert.Equal(1, actual.Length); Assert.Single(actual);
}
[Fact]
public void Combinations_repeated_12_2_equals_11_12_22()
{
int[] data = [1, 2];
var actual = data.CombinationsWithRepeats(2).ToArray();
Assert.Equal([[1, 1], [1, 2], [2, 2]], actual);
}
[Fact]
public void Combinations_repeated_123_2_equals_11_12_13_22_23_33()
{
int[] data = [1, 2, 3];
var actual = data.CombinationsWithRepeats(2).ToArray();
Assert.Equal([[1, 1], [1, 2], [1, 3], [2, 2], [2, 3], [3, 3]], actual);
}
[Fact]
public void Combinations_repeated_123_3_equals_many()
{
int[] data = [1, 2, 3];
var actual = data.CombinationsWithRepeats(3).ToArray();
Assert.Equal([[1, 1, 1], [1, 1, 2], [1, 1, 3], [1, 2, 2], [1, 2, 3], [1, 3, 3], [2, 2, 2], [2, 2, 3], [2, 3, 3], [3, 3, 3]], actual);
} }
[Fact] [Fact]

View File

@ -12,14 +12,15 @@ public static class EnumerableExtensions
var poolLength = pool.Length; var poolLength = pool.Length;
if (count > poolLength) if (count > poolLength)
yield break; yield break;
var indices = Enumerable.Range(0, count).ToArray(); int[] indices = [..Enumerable.Range(0, count)];
yield return GetCombination(indices, pool); yield return GetCombination(indices, pool);
while (true) while (true)
{ {
var idx = count - 1; var idx = count - 1;
for(;idx >= 0; --idx) for(;idx >= 0; --idx)
{ {
var isIndexBelowMax = indices[idx] != idx + poolLength - count; var maxPossibleIndex = idx + poolLength - count;
var isIndexBelowMax = indices[idx] != maxPossibleIndex;
if (isIndexBelowMax) if (isIndexBelowMax)
break; break;
} }
@ -27,7 +28,7 @@ public static class EnumerableExtensions
yield break; yield break;
indices[idx] += 1; indices[idx] += 1;
for (var j = idx + 1; j < count; j++) for (var j = idx + 1; j < indices.Length; j++)
{ {
indices[j] = indices[j - 1] + 1; indices[j] = indices[j - 1] + 1;
} }
@ -36,6 +37,37 @@ public static class EnumerableExtensions
} }
} }
public static IEnumerable<TValue[]> CombinationsWithRepeats<TValue>(this IEnumerable<TValue> values, int count)
{
var pool = values.ToArray();
var poolLength = pool.Length;
if (count > poolLength)
yield break;
int[] indices = [..Enumerable.Repeat(0, count)];
yield return GetCombination(indices, pool);
while (true)
{
var idx = count - 1;
for(;idx >= 0; --idx)
{
var maxPossibleIndex = poolLength - 1;
var isIndexBelowMax = indices[idx] != maxPossibleIndex;
if (isIndexBelowMax)
break;
}
if(idx < 0)
yield break;
var num = indices[idx];
for (var j = idx; j < indices.Length; j++)
{
indices[j] = num + 1;
}
yield return GetCombination(indices, pool);
}
}
public static IEnumerable<TValue[]> Permutations<TValue>(this IEnumerable<TValue> values, int count) public static IEnumerable<TValue[]> Permutations<TValue>(this IEnumerable<TValue> values, int count)
{ {
var pool = values.ToArray(); var pool = values.ToArray();