From 2d9eabf07e74abf9defa2964f8710232f246dba8 Mon Sep 17 00:00:00 2001 From: Maximilian Schuele Date: Thu, 30 Nov 2023 12:21:13 +0100 Subject: [PATCH] limit for mnist --- mnist/duckdb_mnist.py | 15 ++++++++------- mnist/mnist_sql92_bench.sh | 3 ++- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/mnist/duckdb_mnist.py b/mnist/duckdb_mnist.py index 7ac2616..cb3a357 100644 --- a/mnist/duckdb_mnist.py +++ b/mnist/duckdb_mnist.py @@ -4,10 +4,10 @@ from duckdb.typing import * rep=1 -limit=6000 +loadlimit=6000 sizes=[2000,200,20,2] attss=[20] -learningrate=0.01 +learningrate=0.2 createschema = ''' drop table if exists img; @@ -23,7 +23,7 @@ copy mnist from './mnist_train.csv' delimiter ',' HEADER CSV; insert into mnist2 (select row_number() over (), * from mnist limit {}); insert into one_hot(select n.i, n.j, coalesce(i.v,0), i.v from (select id,label+1 as species,1 as v from mnist2) i right outer join (select a.a as i, b.b as j from (select generate_series as a from generate_series(1,{})) a, (select generate_series as b from generate_series(1,10)) b) n on n.i=i.id and n.j=i.species order by i,j); -'''.format(limit,limit) +''' for i in range(1,785): loadmnist += ', pixel{} float'.format(i) loadmnist2 += ', pixel{} float'.format(i) @@ -48,7 +48,7 @@ ), a_xh(i,j,v) as ( SELECT m.i, n.j, 1/(1+exp(-SUM (m.v*n.v))) FROM img AS m INNER JOIN w_now AS n ON m.j=n.i - WHERE n.id=0 and n.iter=(select max(iter) from w_now) -- w_xh + WHERE m.i < {} and n.id=0 and n.iter=(select max(iter) from w_now) -- w_xh GROUP BY m.i, n.j ), a_ho(i,j,v) as ( SELECT m.i, n.j, 1/(1+exp(-SUM (m.v*n.v))) --sig(SUM (m.v*n.v)) @@ -72,6 +72,7 @@ ), d_w(id,i,j,v) as ( SELECT 0, m.j as i, n.j, (SUM (m.v*n.v)) FROM img AS m INNER JOIN d_xh AS n ON m.i=n.i + WHERE m.i < {} GROUP BY m.j, n.j union SELECT 1, m.j as i, n.j, (SUM (m.v*n.v)) @@ -110,12 +111,12 @@ def benchmark(atts,limit,iterations,learning_rate): duckdb.sql(createschema) duckdb.sql(loadmnist) duckdb.sql(loadmnist2) - duckdb.sql(loadmnistrel) + duckdb.sql(loadmnistrel.format(limit,limit)) duckdb.sql(weights.format(784,atts,atts,10)) loadtime = datetime.now() start = datetime.now() for i in range(rep): - result = duckdb.sql(train.format(learning_rate,iterations) + labelmax).fetchall() + result = duckdb.sql(train.format(limit,limit,learning_rate,iterations) + labelmax).fetchall() time=(datetime.now() - start).total_seconds()/rep #name,atts,limit,lr,iter,execution_time,accuracy print("DuckDB-SQL-92,{},{},{},{},{},{}".format(atts,limit,learning_rate,iterations,time,result[0][0])) @@ -123,5 +124,5 @@ def benchmark(atts,limit,iterations,learning_rate): print("name,atts,limit,lr,iter,execution_time,accuracy") for atts in attss: for size in sizes: - iterations=int(60/size) + iterations=int(loadlimit/size) benchmark(atts,size,iterations,learningrate) diff --git a/mnist/mnist_sql92_bench.sh b/mnist/mnist_sql92_bench.sh index 1b49f62..20edeff 100755 --- a/mnist/mnist_sql92_bench.sh +++ b/mnist/mnist_sql92_bench.sh @@ -46,7 +46,7 @@ with recursive w (iter,id,i,j,v) as ( ), a_xh(i,j,v) as ( SELECT m.i, n.j, 1/(1+exp(-SUM (m.v*n.v))) FROM (select * from img) AS m INNER JOIN w_now AS n ON m.j=n.i - WHERE n.id=0 and n.iter=(select max(iter) from w_now) -- w_xh + WHERE m.i < $limit and n.id=0 and n.iter=(select max(iter) from w_now) -- w_xh GROUP BY m.i, n.j ), a_ho(i,j,v) as ( SELECT m.i, n.j, 1/(1+exp(-SUM (m.v*n.v))) @@ -70,6 +70,7 @@ with recursive w (iter,id,i,j,v) as ( ), d_w(id,i,j,v) as ( SELECT 0, m.j as i, n.j, (SUM (m.v*n.v)) FROM (select * from img) AS m INNER JOIN d_xh AS n ON m.i=n.i + WHERE m.i < $limit GROUP BY m.j, n.j union SELECT 1, m.j as i, n.j, (SUM (m.v*n.v))