Commit

support Bedrock
kevintruong committed Nov 21, 2023
1 parent 03c5fa7 commit 4422bae
Showing 6 changed files with 314 additions and 127 deletions.
2 changes: 1 addition & 1 deletion .idea/monkeyFunctions.iml


121 changes: 59 additions & 62 deletions src/monkey_patch/function_modeler.py
@@ -8,26 +8,24 @@
from monkey_patch.utils import approximate_token_count, prepare_object_for_saving, encode_int, decode_int
import copy


EXAMPLE_ELEMENT_LIMIT = 1000


class FunctionModeler(object):
def __init__(self, data_worker, workspace_id = 0, check_for_finetunes = True) -> None:
def __init__(self, data_worker, workspace_id=0, check_for_finetunes=True) -> None:
self.function_configs = {}
self.data_worker = data_worker
self.distillation_token_limit = 3000 # the token limit for finetuning
self.distillation_token_limit = 3000 # the token limit for finetuning
self.align_buffer = {}
self._get_datasets()
self.workspace_id = workspace_id
self.check_for_finetunes = check_for_finetunes


def _get_dataset_info(self, dataset_type, func_hash, type = "length"):
def _get_dataset_info(self, dataset_type, func_hash, type="length"):
"""
Get the dataset size for a function hash
"""
return self.data_worker._load_dataset(dataset_type, func_hash, return_type = type)
return self.data_worker._load_dataset(dataset_type, func_hash, return_type=type)

def _get_datasets(self):
"""
@@ -48,17 +46,16 @@ def save_align_statements(self, function_hash, args, kwargs, output):
successfully_saved, new_datapoint = self.data_worker.log_align(function_hash, example)
if successfully_saved:
if function_hash in self.dataset_sizes["alignments"]:
self.dataset_sizes["alignments"][function_hash] += 1
self.dataset_sizes["alignments"][function_hash] += 1
else:
self.dataset_sizes["alignments"][function_hash] = 1

if new_datapoint:
# update align buffer
if function_hash not in self.align_buffer:
self.align_buffer[function_hash] = bytearray()
self.align_buffer[function_hash].extend(str(example.__dict__).encode('utf-8') + b'\r\n')


def save_datapoint(self, func_hash, example):
"""
Save datapoint to the training data
@@ -68,13 +65,14 @@ def save_datapoint(self, func_hash, example):
if func_hash in self.dataset_sizes["patches"]:
# if the dataset size is -1, it means we havent read in the dataset size yet
if self.dataset_sizes["patches"][func_hash] == -1:
self.dataset_sizes["patches"][func_hash] = self._get_dataset_info("patches", func_hash, type = "length")
self.dataset_sizes["patches"][func_hash] = self._get_dataset_info("patches", func_hash,
type="length")
else:
self.dataset_sizes["patches"][func_hash] += datapoints
else:
self.dataset_sizes["patches"][func_hash] = datapoints
return len(written_datapoints) > 0

def get_alignments(self, func_hash, max=20):
"""
Get all aligns for a function hash
@@ -97,7 +95,7 @@ def get_alignments(self, func_hash, max=20):
# easy and straightforward way to get nr of words (not perfect but doesnt need to be)
# Can do the proper way of tokenizing later, it might be slower and we dont need 100% accuracy
example_element_limit = EXAMPLE_ELEMENT_LIMIT

examples = []
for example_bytes in split_buffer:
if example_bytes in example_set:
@@ -121,18 +119,17 @@ def load_align_statements(self, function_hash):
Load all align statements
"""
if function_hash not in self.align_buffer:
dataset_size, align_dataset = self._get_dataset_info("alignments", function_hash, type = "both")
dataset_size, align_dataset = self._get_dataset_info("alignments", function_hash, type="both")
if align_dataset:
self.align_buffer[function_hash] = bytearray(align_dataset)
self.dataset_sizes["alignments"][function_hash] = dataset_size


def postprocess_datapoint(self, func_hash, function_description, example, repaired=True):
"""
Postprocess the datapoint
"""
try:

added = self.save_datapoint(func_hash, example)
if added:
self._update_datapoint_config(repaired, func_hash)
@@ -147,21 +144,20 @@ def _load_function_config(self, func_hash, function_description):
"""
Load the config file for a function hash
"""

config, default = self.data_worker._load_function_config(func_hash)
if default and self.check_for_finetunes:
finetuned, finetune_config = self._check_for_finetunes(function_description)
if finetuned:
config = finetune_config
self.function_configs[func_hash] = config
return config



def _check_for_finetunes(self, function_description):
        # This here should be discussed, what's the best way to do it

# hash the function_hash into 16 characters
finetune_hash = function_description.__hash__(purpose = "finetune") + encode_int(self.workspace_id)
finetune_hash = function_description.__hash__(purpose="finetune") + encode_int(self.workspace_id)
# List 10 fine-tuning jobs
finetunes = openai.FineTuningJob.list(limit=1000)
# Check if the function_hash is in the fine-tuning jobs
@@ -177,9 +173,9 @@ def _check_for_finetunes(self, function_description):
return True, config
except:
return False, {}

return False, {}

def _construct_config_from_finetune(self, finetune_hash, finetune):
model = finetune["fine_tuned_model"]
# get the ending location of finetune hash in the model name
@@ -190,19 +186,16 @@ def _construct_config_from_finetune(self, finetune_hash, finetune):
nr_of_training_runs = decode_int(next_char) + 1
nr_of_training_points = (2 ** nr_of_training_runs) * 200
config = {
"distilled_model": model,
"current_model_stats": {
"trained_on_datapoints": nr_of_training_points,
"running_faults": []},
"last_training_run": {"trained_on_datapoints": nr_of_training_points},
"current_training_run": {},
"teacher_models": ["gpt-4","gpt-4-32k"], # currently supported teacher models
"nr_of_training_runs": nr_of_training_runs}

return config


"distilled_model": model,
"current_model_stats": {
"trained_on_datapoints": nr_of_training_points,
"running_faults": []},
"last_training_run": {"trained_on_datapoints": nr_of_training_points},
"current_training_run": {},
"teacher_models": ["gpt-4", "gpt-4-32k"], # currently supported teacher models
"nr_of_training_runs": nr_of_training_runs}

return config

def get_models(self, function_description):
"""
@@ -213,7 +206,7 @@ def get_models(self, function_description):
func_config = self.function_configs[func_hash]
else:
func_config = self._load_function_config(func_hash, function_description)

# for backwards compatibility
if "distilled_model" not in func_config:
if func_config["current_model"] in func_config["teacher_models"]:
@@ -224,7 +217,7 @@ def get_models(self, function_description):
distilled_model = func_config["distilled_model"]

return distilled_model, func_config["teacher_models"]

def _update_datapoint_config(self, repaired, func_hash):
"""
Update the config to reflect the new datapoint in the training data
@@ -242,7 +235,7 @@ def _update_datapoint_config(self, repaired, func_hash):
self.function_configs[func_hash]["current_model_stats"]["running_faults"].append(0)
# take the last 100 datapoints
self.function_configs[func_hash]["current_model_stats"]["running_faults"] = \
self.function_configs[func_hash]["current_model_stats"]["running_faults"][-100:]
self.function_configs[func_hash]["current_model_stats"]["running_faults"][-100:]

# check if the last 10 datapoints are 50% faulty, this is the switch condition
if sum(self.function_configs[func_hash]["current_model_stats"]["running_faults"][-10:]) / 10 > 0.5:
@@ -255,11 +248,10 @@ def _update_datapoint_config(self, repaired, func_hash):
print(e)
print("Could not update config file")
pass

def _update_config_file(self, func_hash):
self.data_worker._update_function_config(func_hash, self.function_configs[func_hash])


def check_for_finetuning(self, function_description, func_hash):
"""
Check for finetuning status
@@ -278,7 +270,7 @@ def check_for_finetuning(self, function_description, func_hash):
except Exception as e:
print(e)
print("Error checking for finetuning")

def _check_finetuning_condition(self, func_hash):
"""
Check if the finetuning condition is met
@@ -287,18 +279,19 @@ def _check_finetuning_condition(self, func_hash):
if func_hash not in self.function_configs:
return False


training_threshold = (2 ** self.function_configs[func_hash]["nr_of_training_runs"]) * 200

align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes["alignments"] else 0
patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes["patches"] else 0
align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes[
"alignments"] else 0
patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes[
"patches"] else 0

if patch_dataset_size == -1:
# if havent read in the patch dataset size, read it in
patch_dataset_size = self._get_dataset_info("patches", func_hash, type = "length")
patch_dataset_size = self._get_dataset_info("patches", func_hash, type="length")
self.dataset_sizes["patches"][func_hash] = patch_dataset_size
return (patch_dataset_size + align_dataset_size) > training_threshold

def _execute_finetuning(self, function_description, func_hash):
"""
Execute the finetuning
@@ -308,24 +301,24 @@ def _execute_finetuning(self, function_description, func_hash):
"""
# get function description
function_string = str(function_description.__dict__.__repr__() + "\n")

# get the align dataset
align_dataset = self._get_dataset_info("alignments", func_hash, type = "dataset")
align_dataset = self._get_dataset_info("alignments", func_hash, type="dataset")
if not align_dataset:
align_dataset = ""
else:
align_dataset = align_dataset.decode('utf-8')

# get the patch dataset
patch_dataset = self._get_dataset_info("patches", func_hash, type = "dataset")
patch_dataset = self._get_dataset_info("patches", func_hash, type="dataset")
if not patch_dataset:
patch_dataset = ""
else:
patch_dataset = patch_dataset.decode('utf-8')

if align_dataset == "" and patch_dataset == "":
return

dataset = align_dataset + patch_dataset

dataset.replace("\\n", "[SEP_TOKEN]")
@@ -345,7 +338,7 @@ def _execute_finetuning(self, function_description, func_hash):
"content": f"{instruction}\nFunction: {function_string}---\nInputs:\nArgs: {x['args']}\nKwargs: {x['kwargs']}\nOutput:"},
{"role": "assistant", "content": str(x['output']) if x['output'] is not None else "None"}]}
for x in dataset]

# Create an in-memory text stream
temp_file = io.StringIO()
# Write data to the stream
@@ -358,7 +351,7 @@ def _execute_finetuning(self, function_description, func_hash):
temp_file.seek(0)

# create the finetune hash
finetune_hash = function_description.__hash__(purpose = "finetune")
finetune_hash = function_description.__hash__(purpose="finetune")
nr_of_training_runs = self.function_configs[func_hash]["nr_of_training_runs"]
finetune_hash += encode_int(self.workspace_id)
finetune_hash += encode_int(nr_of_training_runs)
@@ -370,21 +363,23 @@ def _execute_finetuning(self, function_description, func_hash):
return

# here can be sure that datasets were read in as that is checked in the finetune_check
align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes["alignments"] else 0
patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes["patches"] else 0
align_dataset_size = self.dataset_sizes["alignments"][func_hash] if func_hash in self.dataset_sizes[
"alignments"] else 0
patch_dataset_size = self.dataset_sizes["patches"][func_hash] if func_hash in self.dataset_sizes[
"patches"] else 0
total_dataset_size = align_dataset_size + patch_dataset_size
training_file_id = response["id"]
# submit the finetuning job
try:
finetuning_response = openai.FineTuningJob.create(training_file=training_file_id, model="gpt-3.5-turbo",
suffix=finetune_hash)
suffix=finetune_hash)
except Exception as e:
return

self.function_configs[func_hash]["current_training_run"] = {"job_id": finetuning_response["id"],
"trained_on_datapoints": total_dataset_size,
"last_checked": datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S")}
"trained_on_datapoints": total_dataset_size,
"last_checked": datetime.datetime.now().strftime(
"%Y-%m-%d %H:%M:%S")}
# update the config json file
try:
self._update_config_file(func_hash)
@@ -417,11 +412,13 @@ def _update_finetune_config(self, response, func_hash, status):
"""
if status == "failed":
self.function_configs[func_hash]["current_training_run"] = {}
else:
else:
self.function_configs[func_hash]["distilled_model"] = response["fine_tuned_model"]
self.function_configs[func_hash]["last_training_run"] = self.function_configs[func_hash]["current_training_run"]
self.function_configs[func_hash]["last_training_run"] = self.function_configs[func_hash][
"current_training_run"]
self.function_configs[func_hash]["current_model_stats"] = {
"trained_on_datapoints": self.function_configs[func_hash]["current_training_run"]["trained_on_datapoints"],
"trained_on_datapoints": self.function_configs[func_hash]["current_training_run"][
"trained_on_datapoints"],
"running_faults": []}
self.function_configs[func_hash]["nr_of_training_runs"] += 1
self.function_configs[func_hash]["current_training_run"] = {}
@@ -430,4 +427,4 @@ def _update_finetune_config(self, response, func_hash, status):
except Exception as e:
print(e)
print("Could not update config file after a successful finetuning run")
pass
pass
(Diffs for the remaining changed files are not rendered.)
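
Since the Bedrock-specific changes referenced in the commit title live in the files that are not rendered above, the following is a minimal, hypothetical sketch of how a Bedrock-hosted teacher model is typically invoked with boto3. The client setup, region, model ID (anthropic.claude-v2), and prompt format are assumptions for illustration, not code from this commit.

import json
import boto3

# Hypothetical illustration only -- not code from this commit.
# Assumes an Anthropic Claude v2 model served through Amazon Bedrock.
client = boto3.client("bedrock-runtime", region_name="us-east-1")

request_body = json.dumps({
    "prompt": "\n\nHuman: Say hello in one word.\n\nAssistant:",
    "max_tokens_to_sample": 64,
    "temperature": 0.0,
})

response = client.invoke_model(
    modelId="anthropic.claude-v2",  # assumed model identifier
    body=request_body,
    accept="application/json",
    contentType="application/json",
)

# The response body is a stream; decode the JSON payload to read the completion text.
completion = json.loads(response["body"].read())["completion"]
print(completion)

A model invoked this way would presumably be surfaced alongside, or in place of, the OpenAI teacher models ("gpt-4", "gpt-4-32k") that get_models currently reads from the function config.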
