1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python
# coding: utf-8
import random
import heapq
from itertools import chain,product
from pprint import pprint
from copy import copy

"""
Answer:
    16PS - 10 GS
    7PS  - 18 MS
    6MS  - 20 GS
"""

# Number of pupils enrolled in each level.
eleves = {
    "PS": 23,
    "MS": 24,
    "GS": 30,
}

# Two hand-written partitions into three classes. Each class is a dict mapping
# a level to the number of pupils of that level placed in the class; in both
# partitions the per-level totals add up to the figures in `eleves`.
classes_example = [
    {"PS": 6, "MS": 22},
    {"GS": 15, "PS": 17},
    {"GS": 15, "MS": 2},
]
classes_answer = [
    {"PS": 16, "GS": 10},
    {"PS": 7, "MS": 18},
    {"MS": 6, "GS": 20},
]

def classe_size(classe):
    """Return the total number of pupils in one class.

    `classe` maps a level name to the number of pupils of that level.
    """
    total = 0
    for headcount in classe.values():
        total += headcount
    return total

def evaluate(classes):
    """Return the penalty of a partition of the pupils into classes.

    `classes` is a non-empty sequence of dicts, each mapping a level name to
    the number of pupils of that level in the class. The penalty is the sum of:

    * the squared deviation of every class size from the mean class size;
    * ``(6 - t) ** 2`` for every level group of size ``t < 6`` in a class;
    * 50 for every class mixing more than two levels, 10 for every class
      mixing exactly two levels.
    """
    sizes = [classe_size(group) for group in classes]
    mean_size = sum(sizes) / len(sizes)

    # Unequal class sizes are penalised quadratically.
    score = 0
    score += sum((size - mean_size) ** 2 for size in sizes)

    for group in classes:
        # A handful of pupils of one level isolated in a class is penalised.
        for headcount in group.values():
            if headcount < 6:
                shortfall = 6 - headcount
                score += shortfall ** 2

        # Mixed-level classes cost extra; triple levels much more so.
        n_levels = len(group)
        if n_levels == 2:
            score += 10
        elif n_levels > 2:
            score += 50
    return score



# Smallest number of pupils of a single level that may be put in a class; it is
# also the smallest number of pupils of a level that may be left over for the
# other classes.
min_n_per_group = 6


def enumerate_classes(eleves,n_classes_total,min_n,max_n,levels=None,already_in_class=0):
    """Yield every admissible composition of ONE class drawn from `eleves`.

    Levels are decided one after the other (recursively). For each level the
    class takes either no pupil at all or between `min_n_per_group` and all of
    them, provided that what is left of the level is either nothing or at
    least `min_n_per_group` pupils.

    Args:
        eleves: dict mapping a level name to the number of pupils of that
            level still unassigned.
        n_classes_total: number of classes still to be formed out of `eleves`,
            this one included. When it is 1 the class has to take every pupil.
        min_n: minimum total size of the class.
        max_n: room still available in the class (its maximum total size on
            the outermost call).
        levels: levels still to be decided, in processing order. Defaults to
            the sorted keys of `eleves`; the recursion passes the tail.
        already_in_class: pupils already placed in the class by the outer
            recursion steps.

    Yields:
        dict mapping level name to pupil count; levels contributing no pupil
        are omitted. Nothing is yielded when no composition fits.

    Raises:
        AssertionError: if the level being processed holds fewer than
            `min_n_per_group` pupils.
    """
    if levels is None:
        levels = sorted(eleves.keys())
    if not levels:
        # No pupil to place: the empty class is the only candidate, and it is
        # admissible only when it already satisfies the minimum size.
        if already_in_class >= min_n:
            yield {}
        return

    current = levels[0]
    n_current = eleves[current]
    # Raised explicitly (rather than with `assert`) so the check survives -O.
    if n_current < min_n_per_group:
        raise AssertionError(
            "level %r has %r pupils, fewer than min_n_per_group=%r"
            % (current, n_current, min_n_per_group)
        )

    # Pupils of the levels not decided yet: an upper bound on what can still
    # be added to the class after this level.
    n_other_levels = sum(eleves[l] for l in levels if l != current)

    # Candidate counts for this level: none, or a group of acceptable size.
    for n in chain(range(0,1),range(min_n_per_group,1+min(n_current,max_n))):
        takes_all = (n == n_current)
        # The last class must absorb everybody.
        if n_classes_total == 1 and not takes_all:
            continue
        # Can't leave a small group for the other classes.
        if not takes_all and n_current - n < min_n_per_group:
            continue
        # Even with every remaining pupil the class would stay too small.
        if min_n > already_in_class + n + n_other_levels:
            continue

        output = {current : n} if n != 0 else {}
        if len(levels) == 1:
            yield output
            continue

        # n never exceeds n_current (see the range above), so the level is
        # either emptied or reduced.
        remaining = copy(eleves)
        if takes_all:
            del remaining[current]
        else:
            remaining[current] = n_current - n
        for pos in enumerate_classes(remaining,n_classes_total,min_n,max_n-n,levels[1:],already_in_class+n):
            merged = dict(output)
            merged.update(pos)
            yield merged

def enumerate_possibilities(eleves,n_classes,min_n=20,max_n=30):
    """Yield every way of splitting `eleves` into `n_classes` classes.

    The first class is chosen with `enumerate_classes`, its pupils are removed
    from a copy of `eleves`, and the remaining classes are enumerated
    recursively. Classes are ordered, so the same partition is produced once
    per ordering of its classes.

    Args:
        eleves: dict mapping a level name to the number of pupils to place.
            It is not modified.
        n_classes: number of classes to form (at least 1).
        min_n: minimum size of a class (default 20).
        max_n: maximum size of a class (default 30).

    Yields:
        tuple of `n_classes` dicts, each mapping a level name to the number
        of pupils of that level in the class.

    Raises:
        AssertionError: if a class holds more pupils of a level than are
            available (would indicate a bug in `enumerate_classes`).
    """
    if not eleves:
        # Nobody left to place: the remaining classes can only be empty, which
        # is acceptable only when empty classes satisfy the minimum size.
        if min_n <= 0:
            yield tuple({} for _ in range(n_classes))
        return

    for c in enumerate_classes(eleves,n_classes,min_n,max_n):
        if n_classes == 1:
            yield (c,)
            continue

        # Remove the pupils of this class before enumerating the other ones.
        remaining = copy(eleves)
        for level,n in c.items():
            left = remaining[level] - n
            if left < 0:
                # Raised explicitly so the check is not stripped under -O.
                raise AssertionError(
                    "class takes %r pupils of level %r but only %r are available"
                    % (n, level, remaining[level])
                )
            if left == 0:
                del remaining[level]
            else:
                remaining[level] = left
        for pos in enumerate_possibilities(remaining,n_classes-1,min_n,max_n):
            yield (c,) + pos
#print(evaluate(classes_answer))
#qdsf

#print(eleves)
#for c in enumerate_classes(eleves,3,20,30):
#    print(c)
#sdf
        
#el = {'MS': 7, 'GS': 8, 'PS': 6}
#for c in enumerate_classes(el,1,20,30):
#    print(c)

#el = {'MS': 17, 'GS': 18, 'PS': 16}
#for c in enumerate_classes(el,3,20,30):
#    print(c)

# Enumerate every split of `eleves` into three classes, score each distinct
# partition and print the ten with the smallest penalty.
n=0  # number of (ordered) possibilities enumerated, duplicates included
all_pos = set()
for pos in enumerate_possibilities(eleves,3):
    n+=1
    # Canonical, order-independent form of the partition: classes sorted, each
    # class being its sorted (level, count) pairs. Unlike a frozenset of
    # classes, this keeps two identical classes as two entries, and it does
    # not depend on string hashing.
    canonical = tuple(sorted(tuple(sorted(c.items())) for c in pos))
    # Score the canonical ordering so that every ordering of the same
    # partition gets exactly the same floating-point penalty and is
    # de-duplicated by the set.
    all_pos.add((evaluate([dict(c) for c in canonical]),canonical))
# Entries are (penalty, canonical) tuples: plain tuple ordering sorts by
# penalty first and breaks ties deterministically on the partition itself.
pprint(heapq.nsmallest(10,all_pos))