simplex_projection.py
import numpy as np


def simplex_projection(v, b=1):
    r"""Project a vector onto the simplex.

    Implemented according to the paper: Efficient projections onto the
    l1-ball for learning in high dimensions, John Duchi, et al. ICML 2008.
    Implementation Time: 2011 June 17 by Bin@libin AT pmail.ntu.edu.sg

    Optimization Problem: min_{w} \| w - v \|_{2}^{2}
                          s.t. sum_{i=1}^{m} w_{i} = b, w_{i} \geq 0

    Input: A vector v \in R^{m}, and a scalar b > 0 (default=1)
    Output: Projection vector w

    Note: negative entries of v are clipped to zero first and theta is
    floored at zero, so if the positive entries of v sum to less than b,
    the result sums to less than b.

    :Example:
    >>> proj = simplex_projection([.4, .3, -.4, .5])
    >>> print(proj)
    [0.33333333 0.23333333 0.         0.43333333]
    >>> print(round(float(proj.sum()), 8))
    1.0

    Original matlab implementation: John Duchi ([email protected])
    Python-port: Copyright 2012 by Thomas Wiecki ([email protected]).
    """
    v = np.asarray(v, dtype=float)
    p = len(v)
    # Clip negative entries to zero
    v = (v > 0) * v
    # Sort v into u in descending order
    u = np.sort(v)[::-1]
    sv = np.cumsum(u)
    # Last (0-based) index whose sorted entry exceeds its running threshold
    rho = np.where(u > (sv - b) / np.arange(1, p + 1))[0][-1]
    # Amount to subtract from every entry, floored at zero
    theta = max(0.0, (sv[rho] - b) / (rho + 1))
    # Shift by theta and clip negative results to zero
    w = np.maximum(v - theta, 0)
    return w
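
The sort-based routine above can be cross-checked against a slower, independent method. The sketch below is not part of the original file: the name `simplex_projection_bisect` and its `tol` and `max_iter` parameters are hypothetical, and it swaps in plain bisection on the threshold theta instead of the sorting approach from the paper. It assumes the equality-constrained problem stated in the docstring (entries non-negative, summing exactly to b), so it may differ from `simplex_projection` on inputs whose positive entries sum to less than b.

```python
import numpy as np


def simplex_projection_bisect(v, b=1.0, tol=1e-12, max_iter=200):
    """Hypothetical cross-check: bisect on theta so that
    sum(max(v - theta, 0)) equals b (up to tol)."""
    v = np.asarray(v, dtype=float)
    # The sum of max(v - theta, 0) is continuous and non-increasing in theta.
    # At lo every entry exceeds the threshold by at least b / len(v), so the sum is >= b.
    lo = v.min() - b / len(v)
    # At hi no entry exceeds the threshold, so the sum is 0 <= b.
    hi = v.max()
    for _ in range(max_iter):
        theta = 0.5 * (lo + hi)
        if np.maximum(v - theta, 0).sum() > b:
            lo = theta  # sum too large: raise the threshold
        else:
            hi = theta  # sum small enough: lower the threshold
        if hi - lo < tol:
            break
    return np.maximum(v - 0.5 * (lo + hi), 0)


w = simplex_projection_bisect([0.4, 0.3, -0.4, 0.5])
print(np.round(w, 6))
```

For the docstring's example input this should agree with the sort-based result to within the bisection tolerance, with the printed entries close to 1/3, 0.7/3, 0 and 1.3/3.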