arrow_to_bq.py (forked from neo4j-field/neo4j-arrow)

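"""
arrow_to_bq.py -- stream nodes from a Neo4j GDS graph projection into BigQuery.

The script reads node records from a neo4j-arrow service, buffers them into
PyArrow record batches, writes each batch to an in-memory Parquet file, and
loads the result into the configured BigQuery table using a pool of writer
threads. All settings come from environment variables (NEO4J_ARROW_HOST,
NEO4J_USERNAME, NEO4J_GRAPH, BQ_DATASET, BQ_TABLE, etc.).

Example invocation (hypothetical values; anything left unset falls back to the
defaults below):

    NEO4J_ARROW_HOST=neo4j.example.com NEO4J_PASSWORD=secret \
    BQ_DATASET=graph_data BQ_TABLE=nodes python arrow_to_bq.py
"""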
import pyarrow as pa
import pyarrow.parquet  # required so that pa.parquet.write_table() below resolves
from src.main.neo4j_arrow import neo4j_arrow as na
from google.cloud import bigquery
import os, sys, threading, queue
from time import time

### Config
HOST = os.environ.get('NEO4J_ARROW_HOST', 'localhost')
PORT = int(os.environ.get('NEO4J_ARROW_PORT', '9999'))
USERNAME = os.environ.get('NEO4J_USERNAME', 'neo4j')
PASSWORD = os.environ.get('NEO4J_PASSWORD', 'password')
GRAPH = os.environ.get('NEO4J_GRAPH', 'random')
PROPERTIES = [
    p for p in os.environ.get('NEO4J_PROPERTIES', '').split(',')
    if len(p) > 0]
TLS = len(os.environ.get('NEO4J_ARROW_TLS', '')) > 0
TLS_VERIFY = len(os.environ.get('NEO4J_ARROW_TLS_NO_VERIFY', '')) < 1
DATASET = os.environ.get('BQ_DATASET', 'neo4j_arrow')
TABLE = os.environ.get('BQ_TABLE', 'nodes')
DELETE = len(os.environ.get('BQ_DELETE_FIRST', '')) > 0

### Globals
bq_client = bigquery.Client()
data_q = queue.Queue()          # record batches waiting to be loaded into BigQuery
job_q = queue.Queue(maxsize=8)  # not used in this script
done_feeding = threading.Event()
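
# Producer/consumer flow: stream_records() pulls Arrow record batches off the
# neo4j-arrow stream and puts them on data_q in ~100k-row groups; each
# bq_writer() thread takes a group, converts it to Parquet, and starts a
# BigQuery load job. data_q.task_done() is only called from upload_complete()
# (the load job's done-callback), so data_q.join() in stream_records() blocks
# until every load job has finished, not merely until every batch has been
# dequeued.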

### Various Load Options
parquet_options = bigquery.format_options.ParquetOptions()
# Together with use_compliant_nested_type=True in write_to_bigquery(), list
# inference lets BigQuery load Arrow list columns as REPEATED fields.
parquet_options.enable_list_inference = True

job_config = bigquery.LoadJobConfig()
job_config.source_format = bigquery.SourceFormat.PARQUET
job_config.parquet_options = parquet_options

def upload_complete(load_job):
    """Callback fired when a BigQuery load job completes."""
    # print(f'bq upload complete ({load_job.job_id})')
    data_q.task_done()


def write_to_bigquery(writer_id, table, job_config):
    """Convert a PyArrow Table to a Parquet file and load into BigQuery"""
    writer = pa.BufferOutputStream()
    pa.parquet.write_table(table, writer, use_compliant_nested_type=True)
    pq_reader = pa.BufferReader(writer.getvalue())
    load_job = bq_client.load_table_from_file(
        pq_reader, f'{DATASET}.{TABLE}', job_config=job_config)
    load_job.add_done_callback(upload_complete)
    return load_job

def bq_writer(writer_id):
    """Primary logic for a BigQuery writer thread"""
    global done_feeding, job_config
    jobs = []
    print(f"w({writer_id}): writer starting")
    while True:
        try:
            batch = data_q.get(timeout=5)
            if len(batch) < 1:
                break
            table = pa.Table.from_batches(batch, batch[0].schema)
            load_job = write_to_bigquery(writer_id, table, job_config)
            jobs.append(load_job)
        except queue.Empty:
            # use this as a chance to clean up our jobs
            _jobs = []
            for j in jobs:
                if j.running():
                    _jobs.append(j)
                elif j.error_result:
                    # consider this fatal for now...might have hit a rate limit!
                    print(f"w({writer_id}): job {j} had an error {j.error_result}!!!")
                    sys.exit(1)
            jobs = _jobs
            if len(jobs) > 0:
                print(f"w({writer_id}): waiting on {len(jobs)} bq load jobs")
            elif done_feeding.is_set():
                break
    print(f"w({writer_id}): finished")

def stream_records(reader):
    """Consume a neo4j-arrow GDS stream and populate a work queue"""
    print('Start arrow table processing')
    cnt, rows, nbytes = 0, 0, 0
    batch = []
    start = time()
    for chunk, metadata in reader:
        cnt = cnt + chunk.num_rows
        rows = rows + chunk.num_rows
        nbytes = nbytes + chunk.nbytes
        batch.append(chunk)
        if rows >= 100_000:
            data_q.put(batch)
            nbytes = (nbytes >> 20)
            print(f"stream row @ {cnt:,}, batch size: {rows:,} rows, {nbytes:,} MiB")
            batch = []
            rows, nbytes = 0, 0
    if len(batch) > 0:
        # add any remaining data
        data_q.put(batch)

    # signal we're done consuming the source feed and wait for work to complete
    done_feeding.set()
    data_q.join()

    finish = time()
    print(f"Done! Time Delta: {round(finish - start, 1):,}s")
    print(f"Count: {cnt:,} rows, Rate: {round(cnt / (finish - start)):,} rows/s")

if __name__ == "__main__":
    print("Creating neo4j-arrow client")
    client = na.Neo4jArrow(USERNAME, PASSWORD, (HOST, PORT),
                           tls=TLS, verifyTls=TLS_VERIFY)

    print(f"Submitting read job for graph '{GRAPH}'")
    ticket = client.gds_nodes(GRAPH, properties=PROPERTIES)

    print("Starting worker threads")
    threads = []
    for i in range(0, 12):
        t = threading.Thread(target=bq_writer, daemon=True, args=[i])
        threads.append(t)
        t.start()

    print(f"Streaming nodes from {GRAPH} with properties {PROPERTIES}")
    client.wait_for_job(ticket, timeout=180)
    reader = client.stream(ticket)
    stream_records(reader)

    # try to nicely wait for threads to finish just in case
    for t in threads:
        t.join(timeout=60)