Skip to content

Commit

Permalink
limit for mnist
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxEmanuel committed Nov 30, 2023
1 parent e982328 commit 2d9eabf
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
15 changes: 8 additions & 7 deletions mnist/duckdb_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from duckdb.typing import *

rep=1
limit=6000
loadlimit=6000
sizes=[2000,200,20,2]
attss=[20]
learningrate=0.01
learningrate=0.2

createschema = '''
drop table if exists img;
Expand All @@ -23,7 +23,7 @@
copy mnist from './mnist_train.csv' delimiter ',' HEADER CSV;
insert into mnist2 (select row_number() over (), * from mnist limit {});
insert into one_hot(select n.i, n.j, coalesce(i.v,0), i.v from (select id,label+1 as species,1 as v from mnist2) i right outer join (select a.a as i, b.b as j from (select generate_series as a from generate_series(1,{})) a, (select generate_series as b from generate_series(1,10)) b) n on n.i=i.id and n.j=i.species order by i,j);
'''.format(limit,limit)
'''
for i in range(1,785):
loadmnist += ', pixel{} float'.format(i)
loadmnist2 += ', pixel{} float'.format(i)
Expand All @@ -48,7 +48,7 @@
), a_xh(i,j,v) as (
SELECT m.i, n.j, 1/(1+exp(-SUM (m.v*n.v)))
FROM img AS m INNER JOIN w_now AS n ON m.j=n.i
WHERE n.id=0 and n.iter=(select max(iter) from w_now) -- w_xh
WHERE m.i < {} and n.id=0 and n.iter=(select max(iter) from w_now) -- w_xh
GROUP BY m.i, n.j
), a_ho(i,j,v) as (
SELECT m.i, n.j, 1/(1+exp(-SUM (m.v*n.v))) --sig(SUM (m.v*n.v))
Expand All @@ -72,6 +72,7 @@
), d_w(id,i,j,v) as (
SELECT 0, m.j as i, n.j, (SUM (m.v*n.v))
FROM img AS m INNER JOIN d_xh AS n ON m.i=n.i
WHERE m.i < {}
GROUP BY m.j, n.j
union
SELECT 1, m.j as i, n.j, (SUM (m.v*n.v))
Expand Down Expand Up @@ -110,18 +111,18 @@ def benchmark(atts,limit,iterations,learning_rate):
duckdb.sql(createschema)
duckdb.sql(loadmnist)
duckdb.sql(loadmnist2)
duckdb.sql(loadmnistrel)
duckdb.sql(loadmnistrel.format(limit,limit))
duckdb.sql(weights.format(784,atts,atts,10))
loadtime = datetime.now()
start = datetime.now()
for i in range(rep):
result = duckdb.sql(train.format(learning_rate,iterations) + labelmax).fetchall()
result = duckdb.sql(train.format(limit,limit,learning_rate,iterations) + labelmax).fetchall()
time=(datetime.now() - start).total_seconds()/rep
#name,atts,limit,lr,iter,execution_time,accuracy
print("DuckDB-SQL-92,{},{},{},{},{},{}".format(atts,limit,learning_rate,iterations,time,result[0][0]))

print("name,atts,limit,lr,iter,execution_time,accuracy")
for atts in attss:
for size in sizes:
iterations=int(60/size)
iterations=int(loadlimit/size)
benchmark(atts,size,iterations,learningrate)
3 changes: 2 additions & 1 deletion mnist/mnist_sql92_bench.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ with recursive w (iter,id,i,j,v) as (
), a_xh(i,j,v) as (
SELECT m.i, n.j, 1/(1+exp(-SUM (m.v*n.v)))
FROM (select * from img) AS m INNER JOIN w_now AS n ON m.j=n.i
WHERE n.id=0 and n.iter=(select max(iter) from w_now) -- w_xh
WHERE m.i < $limit and n.id=0 and n.iter=(select max(iter) from w_now) -- w_xh
GROUP BY m.i, n.j
), a_ho(i,j,v) as (
SELECT m.i, n.j, 1/(1+exp(-SUM (m.v*n.v)))
Expand All @@ -70,6 +70,7 @@ with recursive w (iter,id,i,j,v) as (
), d_w(id,i,j,v) as (
SELECT 0, m.j as i, n.j, (SUM (m.v*n.v))
FROM (select * from img) AS m INNER JOIN d_xh AS n ON m.i=n.i
WHERE m.i < $limit
GROUP BY m.j, n.j
union
SELECT 1, m.j as i, n.j, (SUM (m.v*n.v))
Expand Down

0 comments on commit 2d9eabf

Please sign in to comment.