diff --git a/AdventOfCode.Extensions.Tests/EnumerableExtensionsTest.cs b/AdventOfCode.Extensions.Tests/EnumerableExtensionsTest.cs index 0bcf279..5635770 100644 --- a/AdventOfCode.Extensions.Tests/EnumerableExtensionsTest.cs +++ b/AdventOfCode.Extensions.Tests/EnumerableExtensionsTest.cs @@ -11,7 +11,7 @@ public class EnumerableExtensionsTest var actual = data.Combinations(2).ToArray(); Assert.Contains([1, 2], actual); - Assert.Equal(1, actual.Length); + Assert.Single(actual); } [Fact] @@ -35,7 +35,37 @@ public class EnumerableExtensionsTest var actual = data.Combinations(3).ToArray(); Assert.Contains([1, 2, 3], actual); - Assert.Equal(1, actual.Length); + Assert.Single(actual); + } + + [Fact] + public void Combinations_repeated_12_2_equals_11_12_22() + { + int[] data = [1, 2]; + + var actual = data.CombinationsWithRepeats(2).ToArray(); + + Assert.Equal([[1, 1], [1, 2], [2, 2]], actual); + } + + [Fact] + public void Combinations_repeated_123_2_equals_11_12_13_22_23_33() + { + int[] data = [1, 2, 3]; + + var actual = data.CombinationsWithRepeats(2).ToArray(); + + Assert.Equal([[1, 1], [1, 2], [1, 3], [2, 2], [2, 3], [3, 3]], actual); + } + + [Fact] + public void Combinations_repeated_123_3_equals_many() + { + int[] data = [1, 2, 3]; + + var actual = data.CombinationsWithRepeats(3).ToArray(); + + Assert.Equal([[1, 1, 1], [1, 1, 2], [1, 1, 3], [1, 2, 2], [1, 2, 3], [1, 3, 3], [2, 2, 2], [2, 2, 3], [2, 3, 3], [3, 3, 3]], actual); } [Fact] diff --git a/AdventOfCode.Extensions/EnumerableExtensions.cs b/AdventOfCode.Extensions/EnumerableExtensions.cs index 3be3644..1994e6d 100644 --- a/AdventOfCode.Extensions/EnumerableExtensions.cs +++ b/AdventOfCode.Extensions/EnumerableExtensions.cs @@ -12,14 +12,15 @@ public static class EnumerableExtensions var poolLength = pool.Length; if (count > poolLength) yield break; - var indices = Enumerable.Range(0, count).ToArray(); + int[] indices = [..Enumerable.Range(0, count)]; yield return GetCombination(indices, pool); while (true) { var idx = count - 1; for(;idx >= 0; --idx) { - var isIndexBelowMax = indices[idx] != idx + poolLength - count; + var maxPossibleIndex = idx + poolLength - count; + var isIndexBelowMax = indices[idx] != maxPossibleIndex; if (isIndexBelowMax) break; } @@ -27,7 +28,7 @@ public static class EnumerableExtensions yield break; indices[idx] += 1; - for (var j = idx + 1; j < count; j++) + for (var j = idx + 1; j < indices.Length; j++) { indices[j] = indices[j - 1] + 1; } @@ -36,6 +37,37 @@ public static class EnumerableExtensions } } + public static IEnumerable CombinationsWithRepeats(this IEnumerable values, int count) + { + var pool = values.ToArray(); + var poolLength = pool.Length; + if (count > poolLength) + yield break; + int[] indices = [..Enumerable.Repeat(0, count)]; + yield return GetCombination(indices, pool); + while (true) + { + var idx = count - 1; + for(;idx >= 0; --idx) + { + var maxPossibleIndex = poolLength - 1; + var isIndexBelowMax = indices[idx] != maxPossibleIndex; + if (isIndexBelowMax) + break; + } + if(idx < 0) + yield break; + + var num = indices[idx]; + for (var j = idx; j < indices.Length; j++) + { + indices[j] = num + 1; + } + + yield return GetCombination(indices, pool); + } + } + public static IEnumerable Permutations(this IEnumerable values, int count) { var pool = values.ToArray();