Skip to content

Commit

Permalink
more scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Dec 6, 2024
1 parent 4341fe1 commit f8ef542
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 3 deletions.
6 changes: 3 additions & 3 deletions json_stats/gennegative.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def gen_example(schema, instance, info):
+ json.dumps(schema, indent=4)
+ "\n\nHere is the valid instance:\n"
+ json.dumps(instance, indent=4)
+ "\n\nPlease modify the instance to make it invalid according to the schema. Focus on corner cases."
+ "\n\nPlease modify the instance to make it invalid according to the schema. Focus on corner cases.\n"
+ info
)
if len(prompt) > 100_000:
if len(prompt) > 200_000:
return {"error": f"Prompt too long, {len(prompt)}"}
req = {
"model": "model",
Expand All @@ -90,7 +90,7 @@ def gen_example(schema, instance, info):
"schema": {},
},
},
"max_tokens": 1000,
"max_tokens": 8000,
"temperature": 0.2,
}

Expand Down
122 changes: 122 additions & 0 deletions json_stats/process_negatives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env python3

import sys
import json
import glob
import os
import random
import re
from jsonschema import Draft202012Validator, validate

prev_base = os.environ.get("HOME") + "/src/json-data/prev_tests"
output_base = os.environ.get("HOME") + "/src/json-data/new_tests"

class Stats:
def __init__(self):
self.files = 0
self.responses = 0
self.server_error = 0
self.json_error = 0
self.json_error_non_length = 0
self.validation_error = 0
self.invalidation_error = 0
self.not_negative = 0
self.negative_added = 0


stats = Stats()


def process_file(file_name):
file_base = file_name.split("/")[-1]

stats.files += 1

with open(file_name) as f:
inp = json.loads(f.read())

pos_data = inp["pos_data"]

prev_file = f"{prev_base}/{file_base}"
if os.path.exists(prev_file):
with open(f"{prev_base}/{file_base}") as f:
pos_data = json.loads(f.read())

schema = pos_data["schema"]
tests = pos_data["tests"]

Draft202012Validator.check_schema(schema)

for idx, test in enumerate( tests ):
try:
validate(test["data"], schema, format_checker=Draft202012Validator.FORMAT_CHECKER)
if not test["valid"]:
print("positive already there", file_name, idx)
stats.invalidation_error += 1
except Exception as e:
if test["valid"]:
stats.validation_error += 1
print("validation error", file_name, idx, repr(e))

for idx, resp in enumerate(inp["responses"]):
stats.responses += 1

if resp.get("error", None):
stats.server_error += 1
continue

rs = resp["choices"][0]["message"]["content"]
try:
r = json.loads(rs)
except:
stats.json_error += 1
if resp["choices"][0]["finish_reason"] != "length":
stats.json_error_non_length += 1
if resp["choices"][0].get("llg_logs", None):
print("non-length-llg", file_name, idx)
else:
print("non-length", file_name, idx)
continue

try:
validate(r, schema, format_checker=Draft202012Validator.FORMAT_CHECKER)
# print("not negative", file_name, idx)
stats.not_negative += 1
continue
except Exception as e:
# good
pass

stats.negative_added += 1

# f"violate a constraint introduced by {f} in the schema"
prompt = resp["expanded_prompt"]
description = "llama-70b generated negative"
m = re.search(r"violate a constraint introduced by (.+?) in the schema", prompt)
if m:
description += "; focus on " + m.group(1)

tests.append({
"description": description,
"valid": False,
"data": r,
})

with open(f"{output_base}/{file_base}", "w") as f:
f.write(json.dumps(pos_data, indent=4))


files = []
for arg in sys.argv[1:]:
if arg.endswith(".json"):
files.append(arg)
else:
files.extend(glob.glob(arg + "/*.json"))
print(len(files))

for idx, f in enumerate( files ):
if idx % 500 == 0:
print(idx, stats.__dict__)
process_file(f)

print(stats.__dict__)
76 changes: 76 additions & 0 deletions json_stats/process_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python3

import sys
import json
import glob
import os
import random
import re
from jsonschema import Draft202012Validator, validate

output_base = os.environ.get("HOME") + "/src/json-data/unique_tests"


class Stats:
def __init__(self):
self.files = 0
self.duplicates = 0
self.pos_tests = 0
self.neg_tests = 0
self.files_without_neg = 0


stats = Stats()


def process_file(file_name):
file_base = file_name.split("/")[-1]

stats.files += 1

with open(file_name) as f:
inp = json.loads(f.read())

pos_data = inp
schema = pos_data["schema"]
tests = pos_data["tests"]

existing_tests = set()
tests_copy = []

num_neg = 0
for t in tests:
key = json.dumps(t["data"])
if key in existing_tests:
stats.duplicates += 1
continue
existing_tests.add(key)
tests_copy.append(t)
if t["valid"]:
stats.pos_tests += 1
else:
stats.neg_tests += 1
num_neg += 1

if num_neg == 0:
stats.files_without_neg += 1

pos_data["tests"] = tests_copy
with open(f"{output_base}/{file_base}", "w") as f:
f.write(json.dumps(pos_data, indent=2))


files = []
for arg in sys.argv[1:]:
if arg.endswith(".json"):
files.append(arg)
else:
files.extend(glob.glob(arg + "/*.json"))
print(len(files))

for idx, f in enumerate(files):
if idx % 500 == 0:
print(idx, stats.__dict__)
process_file(f)

print(stats.__dict__)
63 changes: 63 additions & 0 deletions json_stats/token_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python3

import sys
import json
import glob
import os
import random
import re
from jsonschema import Draft202012Validator, validate

class Stats:
def __init__(self):
self.files = 0
self.responses = 0
self.server_error = 0
self.prompt_tokens = 0
self.completion_tokens = 0


stats = Stats()


def process_file(file_name):
file_base = file_name.split("/")[-1]

stats.files += 1

with open(file_name) as f:
inp = json.loads(f.read())

pos_data = inp["pos_data"]

for idx, resp in enumerate(inp["responses"]):
stats.responses += 1

if resp.get("error", None):
stats.server_error += 1
print("server error", file_name, idx)
continue

if resp["choices"][0]["finish_reason"] == "length":
print("length", file_name, idx)
continue

usage = resp["usage"]
stats.completion_tokens += usage["completion_tokens"]
stats.prompt_tokens += usage["prompt_tokens"]


files = []
for arg in sys.argv[1:]:
if arg.endswith(".json"):
files.append(arg)
else:
files.extend(glob.glob(arg + "/*.json"))
print(len(files))

for idx, f in enumerate( files ):
if idx % 1000 == 0:
print(idx, stats.__dict__)
process_file(f)

print(stats.__dict__)

0 comments on commit f8ef542

Please sign in to comment.