Skip to content

Commit

Permalink
added support for tuple, lists and dicts (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartBakler authored Nov 8, 2023
1 parent b8c0c08 commit e75a03c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
6 changes: 6 additions & 0 deletions src/monkey_patch/assertion_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ def eval_expr(self, node):
return node.n # Assume it's a number for simplicity
elif isinstance(node, ast.Str):
return node.s # Assume it's a string for simplicity
elif isinstance(node, ast.List):
return [self.eval_expr(x) for x in node.elts] # Assume it's a list for simplicity
elif isinstance(node, ast.Tuple):
return tuple([self.eval_expr(x) for x in node.elts]) # Assume it's a tuple for simplicity
elif isinstance(node, ast.Dict):
return {self.eval_expr(k): self.eval_expr(v) for k, v in zip(node.keys, node.values)}
else:
raise NotImplementedError(f"Node type {type(node).__name__} not handled yet")

Expand Down
2 changes: 1 addition & 1 deletion src/monkey_patch/language_models/language_modeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_generation_case(self, args, kwargs, function_modeler, function_descripti
return prompt, distilled_model, suitable_for_distillation, True

else:
aligns = function_modeler.get_alignments(function_description.__hash__(), max=5)
aligns = function_modeler.get_alignments(function_description.__hash__(), max=16)
examples = "\n".join([f"Inputs:\nArgs: {align['args']}\nKwargs: {align['kwargs']}\nOutput: {align['output']}" for align in aligns])
prompt = self.construct_prompt(f, args, kwargs, examples)
examples_token_count = approximate_token_count(examples)
Expand Down
2 changes: 1 addition & 1 deletion src/monkey_patch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _deep_tuple(obj):
:param obj:
:return:
"""
if isinstance(obj, list):
if isinstance(obj, list) or isinstance(obj, tuple):
return tuple(_deep_tuple(e) for e in obj)
elif isinstance(obj, dict):
return tuple((k, _deep_tuple(v)) for k, v in sorted(obj.items()))
Expand Down

0 comments on commit e75a03c

Please sign in to comment.