Refactor `Permutations` method in `EnumerableExtensions` to remove unnecessary logic and add validation for output count in tests

This commit is contained in:
Sebastian Lindemeier 2025-12-15 11:54:40 +01:00
parent acef1dd7d2
commit 149977d02f
2 changed files with 22 additions and 28 deletions

View File

@ -4,13 +4,14 @@ namespace AdventOfCode.Extensions.Tests;
public class EnumerableExtensionsTest public class EnumerableExtensionsTest
{ {
[Fact] [Fact]
public void Combinations_12_2_equals_12_21() public void Combinations_12_2_equals_12()
{ {
int[] data = [1, 2]; int[] data = [1, 2];
var actual = data.Combinations(2).ToArray(); var actual = data.Combinations(2).ToArray();
Assert.Contains([1, 2], actual); Assert.Contains([1, 2], actual);
Assert.Equal(1, actual.Length);
} }
[Fact] [Fact]
@ -23,6 +24,7 @@ public class EnumerableExtensionsTest
Assert.Contains([1, 2], actual); Assert.Contains([1, 2], actual);
Assert.Contains([1, 3], actual); Assert.Contains([1, 3], actual);
Assert.Contains([2, 3], actual); Assert.Contains([2, 3], actual);
Assert.Equal(3, actual.Length);
} }
[Fact] [Fact]
@ -33,6 +35,7 @@ public class EnumerableExtensionsTest
var actual = data.Combinations(3).ToArray(); var actual = data.Combinations(3).ToArray();
Assert.Contains([1, 2, 3], actual); Assert.Contains([1, 2, 3], actual);
Assert.Equal(1, actual.Length);
} }
[Fact] [Fact]
@ -46,6 +49,7 @@ public class EnumerableExtensionsTest
Assert.Contains([1, 2], actual); Assert.Contains([1, 2], actual);
Assert.Contains([2, 1], actual); Assert.Contains([2, 1], actual);
Assert.Contains([2, 2], actual); Assert.Contains([2, 2], actual);
Assert.Equal(4, actual.Length);
} }
[Fact] [Fact]
@ -62,5 +66,6 @@ public class EnumerableExtensionsTest
Assert.Contains([3, 1], actual); Assert.Contains([3, 1], actual);
Assert.Contains([3, 2], actual); Assert.Contains([3, 2], actual);
Assert.Contains([3, 3], actual); Assert.Contains([3, 3], actual);
Assert.Equal(9, actual.Length);
} }
} }

View File

@ -41,39 +41,28 @@ public static class EnumerableExtensions
var pool = values.ToArray(); var pool = values.ToArray();
if (count > pool.Length) if (count > pool.Length)
yield break; yield break;
foreach (var p in InnerPermutations(pool, count))
yield return p;
foreach (var p in InnerPermutations([..pool.Reverse()], count))
yield return p;
yield break;
static IEnumerable<TValue[]> InnerPermutations(TValue[] pool, int count)
{
var indices = Enumerable.Repeat(0, count).ToArray(); var indices = Enumerable.Repeat(0, count).ToArray();
yield return GetCombination(indices, pool); yield return GetCombination(indices, pool);
while (true) while (true)
{ {
var idx = count - 1; var idx = count - 1;
for(;idx >= 0; --idx) for(;idx > 0; --idx)
{ {
var isIndexBelowMax = indices[idx] != idx + pool.Length - count; var isIndexBelowMax = indices[idx] != idx + pool.Length - count;
if (isIndexBelowMax) if (isIndexBelowMax)
break; break;
} }
if(idx < 0) if(indices.All(i => i >= pool.Length - 1))
break; yield break;
indices[idx] += 1; indices[idx] += 1;
for (var j = idx + 1; j < count; j++) for (var j = idx + 1; j < count; j++)
{ {
indices[j] = indices[j - 1]; indices[j] = 0;
} }
yield return GetCombination(indices, pool); yield return GetCombination(indices, pool);
} }
} }
}
private static TValue[] GetCombination<TValue>(int[] innerIndices, TValue[] innerPool) => private static TValue[] GetCombination<TValue>(int[] innerIndices, TValue[] innerPool) =>
innerIndices.Select(i => innerPool[i]).ToArray(); innerIndices.Select(i => innerPool[i]).ToArray();