-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
jl_kmeans.clj
404 lines (360 loc) · 17.4 KB
/
jl_kmeans.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
(ns kmeans-mnist.jl-kmeans
"KMeans implementation combining Clojure and Julia"
(:require [libjulia-clj.julia :refer [jl] :as julia]
[tech.v3.datatype :as dtype]
[tech.v3.datatype.argops :as argops]
[tech.v3.datatype.errors :as errors]
[tech.v3.datatype.functional :as dfn]
[tech.v3.tensor :as dtt]
[tech.v3.parallel.for :as pfor]
[tech.v3.datatype.reductions :as reductions]
[tech.v3.libs.buffered-image :as bufimg]
[clojure.java.io :as io]
[clojure.tools.logging :as log])
(:import [java.util Random HashMap]
[java.util.function BiFunction BiConsumer]
[tech.v3.datatype ArrayHelpers IndexReduction
IndexReduction$IndexedBiFunction]))
;; Warn at compile time about reflective Java interop, and use unchecked
;; primitive math while warning whenever an arithmetic call ends up boxed.
(set! *warn-on-reflection* true)
(set! *unchecked-math* :warn-on-boxed)
;; If you enable threads in Julia, JVM signal forwarding must be enabled.
;; `defonce` plus `delay` ensures the initialize call below is only ever run
;; once per JVM, even if this namespace is reloaded; the deref on the following
;; line forces it at load time.
;; NOTE(review): `:n-threads -1` presumably asks the runtime to pick the thread
;; count itself -- confirm against the libjulia-clj documentation.
(defonce init* (delay (julia/initialize! {:n-threads -1
:optimization-level 3})))
@init*
;; Use to make sure memory is being released as it should be
(julia/set-julia-gc-root-log-level! :info)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Loading initial julia symbols
;; The `do` groups these defs so they can be re-evaluated together as one form.
;; `loader` must come first: it evaluates the text of the `kmeans.jl` classpath
;; resource in Julia. The remaining defs each evaluate a Julia name and keep the
;; result as a Clojure var (presumably the functions defined by kmeans.jl --
;; verify against that file).
(do
(def loader (jl (slurp (io/resource "kmeans.jl"))))
(def kmeans-next-centroid (jl "kmeans_next_centroid"))
(def assign-centroids (jl "assign_centroids"))
(def assign-centroids-imr (jl "assign_centroids_imr"))
(def assign-calc-centroids (jl "assign_calc_centroids"))
(def score-kmeans (jl "score_kmeans"))
(def order-data-labels (jl "order_data_labels"))
(def per-label-infer (jl "per_label_infer"))
)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Clojure versions of some functions for performance comparison
(defmacro distance-squared
  "Expands to a loop that sums, over columns `0..ncols-1`, the squared
  difference between element `[row-idx col]` of `dataset` and element
  `[centroid-idx col]` of `centroids`.  Both arguments are read through
  `.ndReadDouble` and the expansion evaluates to a double."
  [dataset centroids row-idx centroid-idx ncols]
  `(double
    (loop [col# 0
           acc# 0.0]
      (if-not (< col# ~ncols)
        acc#
        (let [delta# (- (.ndReadDouble ~dataset ~row-idx col#)
                        (.ndReadDouble ~centroids ~centroid-idx col#))]
          (recur (unchecked-inc col#) (+ acc# (* delta# delta#))))))))
(defn jvm-next-centroid
  "JVM implementation of the per-centroid distance update used while choosing
  initial centroids; it takes the same arguments as the `:distance-fn` passed to
  `choose-centroids++`.

  For every row of `dataset` the squared distance to row `idx` of `centroids`
  is computed in parallel, and `distances` is lowered to that value wherever it
  is smaller than what is already stored.  Afterwards `scan-distances` is
  overwritten with the running (cumulative) sum of `distances`.

  Mutates `distances` and `scan-distances` in place and returns nil."
  [dataset centroids idx distances scan-distances]
  (let [data-tens (dtt/as-tensor dataset)
        center-tens (dtt/as-tensor centroids)
        min-dists (dtype/->buffer distances)
        running (dtype/->buffer scan-distances)
        center-row (long idx)
        [row-count col-count] (dtype/shape data-tens)
        row-count (long row-count)
        col-count (long col-count)]
    ;; Pass 1 (parallel): keep the smallest squared distance seen per row.
    (pfor/parallel-for
     row row-count
     (let [d2 (distance-squared data-tens center-tens row center-row col-count)]
       (when (< d2 (.readDouble min-dists row))
         (.writeDouble min-dists row d2))))
    ;; Pass 2 (serial): each entry depends on the previous one, so the
    ;; cumulative sum is built in order.
    (.writeDouble running 0 (.readDouble min-dists 0))
    (loop [i 1]
      (when (< i row-count)
        (.writeDouble running i
                      (+ (.readDouble running (unchecked-dec i))
                         (.readDouble min-dists i)))
        (recur (unchecked-inc i))))))
(defn- seed->random
  "Coerce `seed` into a `java.util.Random`.

  * a `Random` instance is returned unchanged,
  * nil produces a freshly constructed, unseeded `Random`,
  * any number is used as the seed of a new `Random`.

  Throws (via `errors/throwf`) for any other type."
  ^Random [seed]
  (cond
    (instance? Random seed)
    seed
    (nil? seed)
    (Random.)
    (number? seed)
    ;; `Random` takes a long seed.  Coercing with `int` threw for any seed
    ;; outside the 32 bit range (e.g. a millisecond timestamp); `long` accepts
    ;; those and yields the same generator for seeds that did fit in an int.
    (Random. (long seed))
    :else
    (errors/throwf "Unrecognized seed type: %s" seed)))
(defn- choose-centroids++
  "Pick `n-centroids` initial centers from the rows of `dataset` using a
  variation of the kmeans++ seeding scheme.

  The first center is a uniformly random row.  For each further center,
  `distance-fn` is called as
  `(distance-fn dataset centroids idx distances scan-distances)` to refresh the
  per-row distances and their cumulative sum; a random target in
  `[0, total distance)` is then located in the cumulative vector with a binary
  search rather than by sorting the distances.  The row *after* the insert
  position is taken (clamped to the last row) because that row is nearly always
  the one whose distance is larger than the entry at the insert position.

  Options map:
  * `:seed` - anything accepted by `seed->random`.
  * `:distance-fn` - function with the signature shown above.

  Returns the Julia array holding the chosen centroids."
  [dataset n-centroids {:keys [seed distance-fn]}]
  (let [[row-count col-count] (dtype/shape dataset)
        ;; Julia is column major while datatype is row major, so the shape
        ;; arguments are reversed here.  Allocated outside the stack context
        ;; below because it is the return value.
        centroids (julia/new-array [col-count n-centroids] :float64)]
    (julia/with-stack-context
      (let [center-rows (dtt/as-tensor centroids)
            data-rows (dtt/as-tensor dataset)
            ^Random rng (seed->random seed)
            row-count (long row-count)
            distances (julia/new-array [row-count] :float64)
            ;; Start at +max so the first distance pass always lowers it.
            _ (dtype/set-constant! distances Double/MAX_VALUE)
            scan-distances (julia/new-array [row-count] :float64)
            cumulative (dtt/as-tensor scan-distances)
            first-pick (.nextInt rng (int row-count))
            _ (dtt/mset! center-rows 0 (dtt/mget data-rows first-pick))
            remaining (dec (long n-centroids))
            max-row (dec row-count)]
        (dotimes [idx remaining]
          (distance-fn dataset centroids idx distances scan-distances)
          (let [unit-sample (.nextDouble rng)
                ;; Last entry of the cumulative sum is the total distance.
                total (double (cumulative (dec (dtype/ecount distances))))
                target (* unit-sample total)
                insert-pos (argops/binary-search cumulative target)
                pick (min max-row (inc insert-pos))]
            (dtt/mset! center-rows (inc idx) (dtt/mget data-rows pick))))
        centroids))))
;; Accumulator used while reducing rows into a new centroid:
;; * `center` - double array the reduction adds row values into, one slot per
;;   column.
;; * `n-rows` - long array whose slot 0 counts the rows that were added.
(defrecord AggReduceContext [^doubles center
^longs n-rows])
(defn- new-center-reduction
  "Build an `IndexReduction` over `dataset` that accumulates rows into an
  `AggReduceContext`.

  * `reduceIndex` adds every column of row `row-idx` into the context's
    `center` array and bumps its row counter, creating the context when none
    was passed in.
  * `reduceReductions` folds the right-hand context into the left-hand one and
    returns the left-hand one.

  `dataset` is read as a flat buffer, with row `r` starting at element
  `(* n-cols r)`."
  ^IndexReduction [dataset]
  (let [row-width (long (second (dtype/shape dataset)))
        flat-data (dtype/->buffer dataset)
        fresh-context (fn []
                        (->AggReduceContext (double-array row-width)
                                            (long-array 1)))]
    (reify IndexReduction
      (reduceIndex [this batch ctx row-idx]
        (let [^AggReduceContext acc (or ctx (fresh-context))
              ^doubles sums (.center acc)
              ^longs tally (.n-rows acc)
              base (* row-width row-idx)]
          (dotimes [col row-width]
            (ArrayHelpers/accumPlus sums col
                                    (.readDouble flat-data (+ base col))))
          (ArrayHelpers/accumPlus tally 0 1)
          acc))
      (reduceReductions [this lhsCtx rhsCtx]
        (let [^AggReduceContext into-ctx lhsCtx
              ^AggReduceContext from-ctx rhsCtx
              ^doubles into-sums (.center into-ctx)
              ^doubles from-sums (.center from-ctx)
              ^longs into-tally (.n-rows into-ctx)
              ^longs from-tally (.n-rows from-ctx)]
          (dotimes [col row-width]
            (ArrayHelpers/accumPlus into-sums col (aget from-sums col)))
          (ArrayHelpers/accumPlus into-tally 0 (aget from-tally 0))
          into-ctx)))))
;; The fastest approach measured was a combined kmeans iteration: Julia assigns
;; each row its centroid index and the JVM sums the rows into the new centroids.
(defn jvm-assign-centers-from-centroid-indexes
  "Compute new centroids on the JVM from an existing row->center assignment.

  * `dataset` - matrix whose rows are averaged.
  * `n-centers` - number of rows in the returned tensor.
  * `center-indexes` - per-row center index used as the grouping key.

  Rows are grouped by their entry in `center-indexes` and summed with
  `new-center-reduction`; each group's sum divided by its row count is written
  to the matching row of a new `[n-centers n-cols]` tensor, which is returned."
  [dataset n-centers center-indexes]
  (let [grouped (reductions/ordered-group-by-reduce
                 (new-center-reduction dataset) nil center-indexes)
        feature-count (second (dtype/shape dataset))
        new-centroids (dtt/new-tensor [n-centers feature-count])
        write-mean! (reify BiConsumer
                      (accept [this center-idx reduce-context]
                        (let [^doubles sums (:center reduce-context)
                              ^longs tally (:n-rows reduce-context)]
                          (dtt/mset! new-centroids center-idx
                                     (dfn// sums (aget tally 0))))))]
    (.forEach ^HashMap grouped write-mean!)
    new-centroids))
;; REPL scratchpad: timing comparisons between the Julia and JVM code paths.
;; Nothing in this form is evaluated when the namespace loads.  The timings in
;; the trailing comments are the original author's measurements.
(comment
;; Load an image and flatten it to an [nrows ncols] Julia matrix where each row
;; is one pixel, plus an int32 Julia vector to receive per-row centroid indexes.
(do
(def src-image (bufimg/load "data/jen.jpg"))
(def img-height (first (dtype/shape src-image)))
(def img-width (second (dtype/shape src-image)))
(def nrows (* img-width img-height))
(def ncols (/ (dtype/ecount src-image)
nrows))
(def dataset (-> (dtt/reshape src-image [nrows ncols])
(julia/->array)))
(def centroid-indexes (-> (dtt/new-tensor [nrows] :datatype :int32)
(julia/->array)))
)
(def centroids (time (choose-centroids++
dataset 5 {:seed 5 :distance-fn kmeans-next-centroid})))
;; 285ms
(def centroids (time (choose-centroids++
dataset 5 {:seed 5 :distance-fn jvm-next-centroid})))
;;1163ms
(def score (time (assign-centroids dataset centroids centroid-indexes)))
;; 103ms
(def score (time (assign-centroids-imr dataset centroids centroid-indexes)))
;; 103ms
;; NOTE(review): `jvm-assign-centroids` is not defined in the visible source --
;; this line looks stale; confirm before evaluating.
(def score (time (jvm-assign-centroids dataset centroids centroid-indexes)))
;; 775ms
(def jvm-centroids
(time (do
(assign-centroids-imr dataset centroids centroid-indexes)
(jvm-assign-centers-from-centroid-indexes dataset
(count centroids)
centroid-indexes))))
;; 200ms
;; NOTE(review): `new-centroids` is not defined in this scratchpad or elsewhere
;; in the visible source -- define it before evaluating the next form.
(def jl-centroids (time (assign-calc-centroids dataset centroids new-centroids)))
;; 400ms -> 1700ms, varying
(jl "ccall(:jl_in_threaded_region, Cint, ())")
)
(defn kmeans++
  "Find K cluster centroids via kmeans++ center initialization
  followed by Lloyd's algorithm.

  Dataset must be a matrix (2d tensor).

  * `dataset` - 2d matrix of numeric datatype.
  * `n-centroids` - Either the number of centroids to find, or a rank 2 matrix
    of initial centroids (one centroid per row) in which case the kmeans++
    initialization is skipped and that matrix is used as the starting point.

  Returns map of:

  * `:centroids` - 2d tensor of double centroids.
  * `:centroid-indexes` - 1d int32 vector of assigned center indexes.
  * `:iteration-scores` - vector of mean squared error scores.  It holds one
    score per executed iteration (measured with the centroids that iteration
    started from) followed by the score of the final centroids, so it has
    `iterations-run + 1` entries and at most `n-iters + 1`.

  Options:

  * `:minimal-improvement-threshold` - defaults to 0.01 - the algorithm
    terminates once the relative improvement `(1.0 - score(n)/score(n-1))` is
    not greater than this value.
  * `:n-iters` - defaults to 100 - max number of iterations.  `0` skips Lloyd's
    algorithm entirely and only scores the initial centroids.
  * `:rand-seed` - integer or implementation of `java.util.Random`."
  [dataset n-centroids & [{:keys [n-iters rand-seed
                                  minimal-improvement-threshold]
                           :or {minimal-improvement-threshold 0.01}}]]
  (errors/when-not-error
   (== 2 (dtype/ecount (dtype/shape dataset)))
   "Dataset must be a matrix of rank 2")
  (let [initial-centroids? (not (number? n-centroids))
        _ (when initial-centroids?
            (errors/when-not-error
             (== 2 (count (dtype/shape n-centroids)))
             "Centroids must be rank 2"))
        ;; The centroid *count* is needed for logging and for sizing the
        ;; recomputed centroid tensor; `n-centroids` itself may be a matrix.
        n-centers (long (if initial-centroids?
                          (first (dtype/shape n-centroids))
                          n-centroids))
        n-rows (long (first (dtype/shape dataset)))
        n-iters (long (or n-iters 100))
        ;; Same fallback as the `:or` default so an explicit nil behaves like
        ;; an absent key.
        minimal-improvement-threshold (double (or minimal-improvement-threshold
                                                  0.01))]
    (log/infof "Choosing n-centroids %d with %f improvement threshold and max %d iters"
               n-centers minimal-improvement-threshold n-iters)
    (julia/with-stack-context
      (let [dataset (julia/->array dataset)
            centroids (if initial-centroids?
                        (julia/->array n-centroids)
                        (choose-centroids++ dataset n-centers
                                            {:seed rand-seed
                                             :distance-fn kmeans-next-centroid}))
            centroid-indexes (julia/->array (dtt/new-tensor [n-rows] :datatype :int32))
            dec-n-iters (dec n-iters)
            scores (if-not (== 0 n-iters)
                     (loop [iter-idx 0
                            last-score 0.0
                            scores []]
                       ;; Julia assigns every row to its nearest centroid and
                       ;; returns the summed error; the JVM then averages the
                       ;; rows of each group into the next set of centroids.
                       (let [score (assign-centroids-imr dataset centroids
                                                         centroid-indexes)
                             new-centroids (jvm-assign-centers-from-centroid-indexes
                                            dataset n-centers centroid-indexes)
                             score (/ (double score) n-rows)
                             rel-score (if-not (== 0.0 last-score)
                                         (- 1.0 (/ score last-score))
                                         1.0)
                             scores (conj scores score)]
                         (dtype/copy! new-centroids centroids)
                         (log/infof "Iteration %d out of %d - relative improvement %f->%f=%f"
                                    iter-idx n-iters last-score score rel-score)
                         (if (and (< iter-idx dec-n-iters)
                                  (not (== 0.0 score))
                                  (> rel-score minimal-improvement-threshold))
                           (recur (unchecked-inc iter-idx) score scores)
                           ;; The score of the terminating iteration is kept
                           ;; too; previously it was silently dropped.
                           scores)))
                     [])
            final-score (score-kmeans dataset centroids)]
        ;;Clone data back into jvm land to escape the resource context
        {:centroids (dtt/clone centroids)
         :centroid-indexes (dtt/clone centroid-indexes)
         :iteration-scores (conj scores (/ (double final-score)
                                           (double n-rows)))}))))
;; REPL scratchpad: run `kmeans++` end to end on an image whose pixels have been
;; flattened into an [nrows ncols] matrix.  Nothing here is evaluated when the
;; namespace loads; the timing is the original author's measurement.
(comment
(do
(def src-image (bufimg/load "data/jen.jpg"))
(def img-height (first (dtype/shape src-image)))
(def img-width (second (dtype/shape src-image)))
(def nrows (* img-width img-height))
(def ncols (/ (dtype/ecount src-image) nrows))
(def dataset (-> (dtt/reshape src-image [nrows ncols]))))
(def img-data (time (kmeans++ dataset 5 {:rand-seed 5})))
;;2716ms
)
(defn- concatenate-results
  "Merge a sequence of result maps into a single map of tensors.

  The keys (and each tensor's element datatype) are taken from the first map;
  for every key the values of all maps are stacked into one tensor, which
  therefore has one more dimension than the individual values.  Requires every
  result to have the same shape per key.  Returns nil for an empty sequence."
  [result-seq]
  (when-let [[exemplar] (seq result-seq)]
    (into {}
          (map (fn [[field sample]]
                 [field (dtt/->tensor (mapv field result-seq)
                                      :datatype (dtype/elemwise-datatype sample))]))
          exemplar)))
(defn train-per-label
  "Given a dataset along with per-row integer labels, train `n-per-label`
  kmeans centroids for every distinct label, returning a model which you can
  use with `predict-per-label`.

  * `data` - 2d matrix, one row per sample.
  * `labels` - per-row labels.
  * `n-per-label` - number of centroids trained for each label.

  Options (also passed through to `kmeans++`):

  * `:input-ordered?` - when truthy, `data`/`labels` are assumed to already be
    grouped so that each label occupies one contiguous run of rows; otherwise
    they are reordered with the Julia `order_data_labels` function first.

  Returns nil when `labels` is empty, otherwise a map of:

  * `:centroids` - per-label centroids stacked into one tensor.
  * `:labels` - the labels, in increasing order.
  * `:iteration-scores` - the final kmeans score for each label.
  * `:kmeans-type` - always `:n-per-label`."
  [data labels n-per-label & [{:keys [input-ordered?]
                               :as options}]]
  (julia/with-stack-context
    (when-not (empty? labels)
      ;;Organize data per-label
      (let [n-per-label (long n-per-label)
            [data labels] (if input-ordered?
                            [(julia/->array data) (julia/->array labels)]
                            ;;Order data and labels by increasing index
                            (order-data-labels (julia/->array data)
                                               (julia/->array labels)))
            ;; arggroup by default uses an 'ordered' algorithm that guarantees
            ;; each index list is ordered, so first/last bound the label's run.
            label-ranges (->> (argops/arggroup labels)
                              (into {})
                              (sort-by first)
                              (mapv (fn [[label idx-list]]
                                      ;; `range` below excludes its end, so the
                                      ;; end must be one *past* the last index;
                                      ;; using the last index itself dropped the
                                      ;; final row of every label.
                                      [label
                                       [(long (first idx-list))
                                        (inc (long (last idx-list)))]])))]
        (->> label-ranges
             (mapv (fn [[label [idx-start past-idx-end]]]
                     ;;Tensor selection from contiguous data of a range with an
                     ;;increment of 1 is guaranteed to produce contiguous data
                     (log/infof "Training centroids for label %s" label)
                     (let [{:keys [centroids iteration-scores]}
                           (-> (dtt/select data (range idx-start past-idx-end))
                               (kmeans++ n-per-label options))]
                       {:centroids centroids
                        :labels label
                        :iteration-scores (last iteration-scores)})))
             (concatenate-results)
             (merge {:kmeans-type :n-per-label}))))))
(defn predict-per-label
  "Using a per-label `model`, find the nearest centroid to each row of
  `dataset`.

  The model's `:centroids` tensor is expected to have shape
  `[n-labels n-per-label n-cols]`; it is flattened to
  `[(* n-labels n-per-label) n-cols]` and handed, together with the dataset, to
  the Julia `per_label_infer` function which fills in one int32 per row.

  Throws when the dataset and the model disagree on the number of columns.

  Returns:

  * `:label-indexes` - int32 assigned indexes for each row in the dataset."
  [dataset model]
  (julia/with-stack-context
    (let [model-centroids (:centroids model)
          [n-labels n-per-label n-cols] (dtype/shape model-centroids)
          [n-rows n-data-cols] (dtype/shape dataset)]
      (errors/when-not-errorf
       (= n-cols n-data-cols)
       "Data (%d), model (%d) have different feature counts"
       n-data-cols n-cols)
      (let [jl-dataset (julia/->array dataset)
            total-centroids (* (long n-labels)
                               (long n-per-label))
            flat-centroids (julia/->array
                            (dtt/reshape model-centroids [total-centroids n-cols]))
            indexes (julia/new-array [n-rows] :int32)]
        (per-label-infer jl-dataset flat-centroids n-labels indexes)
        ;; Clone so the result outlives the Julia stack context.
        {:label-indexes (dtype/clone indexes)}))))