Refactor `Permutations` method in `EnumerableExtensions` to remove unnecessary logic and add validation for output count in tests

This commit is contained in:
Sebastian Lindemeier 2025-12-15 11:54:40 +01:00
parent acef1dd7d2
commit 149977d02f
2 changed files with 22 additions and 28 deletions

View File

@ -4,13 +4,14 @@ namespace AdventOfCode.Extensions.Tests;
public class EnumerableExtensionsTest
{
[Fact]
public void Combinations_12_2_equals_12_21()
public void Combinations_12_2_equals_12()
{
int[] data = [1, 2];
var actual = data.Combinations(2).ToArray();
Assert.Contains([1, 2], actual);
Assert.Equal(1, actual.Length);
}
[Fact]
@ -23,6 +24,7 @@ public class EnumerableExtensionsTest
Assert.Contains([1, 2], actual);
Assert.Contains([1, 3], actual);
Assert.Contains([2, 3], actual);
Assert.Equal(3, actual.Length);
}
[Fact]
@ -33,6 +35,7 @@ public class EnumerableExtensionsTest
var actual = data.Combinations(3).ToArray();
Assert.Contains([1, 2, 3], actual);
Assert.Equal(1, actual.Length);
}
[Fact]
@ -46,6 +49,7 @@ public class EnumerableExtensionsTest
Assert.Contains([1, 2], actual);
Assert.Contains([2, 1], actual);
Assert.Contains([2, 2], actual);
Assert.Equal(4, actual.Length);
}
[Fact]
@ -62,5 +66,6 @@ public class EnumerableExtensionsTest
Assert.Contains([3, 1], actual);
Assert.Contains([3, 2], actual);
Assert.Contains([3, 3], actual);
Assert.Equal(9, actual.Length);
}
}

View File

@ -41,37 +41,26 @@ public static class EnumerableExtensions
var pool = values.ToArray();
if (count > pool.Length)
yield break;
foreach (var p in InnerPermutations(pool, count))
yield return p;
foreach (var p in InnerPermutations([..pool.Reverse()], count))
yield return p;
yield break;
static IEnumerable<TValue[]> InnerPermutations(TValue[] pool, int count)
var indices = Enumerable.Repeat(0, count).ToArray();
yield return GetCombination(indices, pool);
while (true)
{
var indices = Enumerable.Repeat(0, count).ToArray();
yield return GetCombination(indices, pool);
while (true)
var idx = count - 1;
for(;idx > 0; --idx)
{
var idx = count - 1;
for(;idx >= 0; --idx)
{
var isIndexBelowMax = indices[idx] != idx + pool.Length - count;
if (isIndexBelowMax)
break;
}
if(idx < 0)
var isIndexBelowMax = indices[idx] != idx + pool.Length - count;
if (isIndexBelowMax)
break;
indices[idx] += 1;
for (var j = idx + 1; j < count; j++)
{
indices[j] = indices[j - 1];
}
yield return GetCombination(indices, pool);
}
if(indices.All(i => i >= pool.Length - 1))
yield break;
indices[idx] += 1;
for (var j = idx + 1; j < count; j++)
{
indices[j] = 0;
}
yield return GetCombination(indices, pool);
}
}