Skip to content

Commit

Permalink
feat(compiler,register): now matching on containing class and docstring
Browse files Browse the repository at this point in the history
name cannot be retrieved reliably at runtime after compiling to javascript. now matches on docstring
and parentName
  • Loading branch information
Jack Hopkins committed Feb 1, 2024
1 parent cb9ac48 commit f47b23f
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 22 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ export AWS_SECRET_ACCESS_KEY=...
export AWS_ACCESS_KEY_ID=...
```

Next, we need to install the Tanuki type transformer. This will allow Tanuki to be aware of your patched functions and types at runtime, as these types are usually erased by the Typescript compiler when transpiling into Javascript.
```typescript
npm install ts-patch --save-dev
npx ts-patch install
```

Next, you need to add the Tanuki transformer to your `tsconfig.json` file:

```json
Expand Down
10 changes: 10 additions & 0 deletions src/models/functionDescription.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import { JSONSchema } from './jsonSchema';
class FunctionDescription {
name: string;
docstring: string;
parentName?: string;
sourceFile?: string;
inputTypeDefinition?: string
inputTypeSchema?: JSONSchema;
outputTypeDefinition?: string;
Expand All @@ -14,6 +16,8 @@ class FunctionDescription {
constructor(
name: string,
docstring: string,
parentName?: string,
sourceFile?: string,
inputTypeDefinition?: string,
outputTypeDefinition?: string,
inputTypeSchema?: JSONSchema,
Expand All @@ -22,6 +26,12 @@ class FunctionDescription {
) {
this.name = name;
this.docstring = docstring;
if (parentName != null) {
this.parentName = parentName;
}
if (sourceFile != null) {
this.sourceFile = sourceFile;
}
if (inputTypeDefinition != null) {
this.inputTypeDefinition = inputTypeDefinition;
}
Expand Down
53 changes: 51 additions & 2 deletions src/register.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import { REGISTERED_FUNCTIONS_FILENAME } from "./constants";
interface FunctionDescriptionJSON {
name: string;
docstring: string;
parentName?: string;
sourceFile?: string;
inputTypeDefinition: string;
outputTypeDefinition: string;
inputTypeSchema: JSONSchema;
Expand Down Expand Up @@ -79,6 +81,8 @@ export class Register {
const pf = new FunctionDescription(
pfj.name,
pfj.docstring,
pfj.parentName,
pfj.sourceFile,
undefined,
undefined,
pfj.inputTypeSchema,
Expand Down Expand Up @@ -165,8 +169,53 @@ export class Register {
return functionDescription;
throw new Error("Method not implemented.");
}*/
static getNamedFunctions(): string[] {
return [...Object.keys(this.alignableSymbolicFunctions), ...Object.keys(this.alignableEmbeddingFunctions)];
// @ts-ignore
static getNamedFunctions(classContext, docstring: string): FunctionDescription {
const className = classContext.name;
const classPrefix = className + '.';

// Gather all function names from alignable functions
const allFunctionNames = [
...Object.keys(this.alignableSymbolicFunctions),
...Object.keys(this.alignableEmbeddingFunctions)
];

const filterFunctions = (functions: { // @ts-ignore
[p: string]: FunctionDescription }) => {
return Object.values(functions)
.filter(funcDesc => funcDesc.parentName === classContext.name)
.filter(funcDesc => docstring === "" || funcDesc.docstring === docstring)
.map(funcDesc => funcDesc);
};

// Apply the filter to both symbolic and embedding functions
const symbolicFunctionNames = filterFunctions(this.alignableSymbolicFunctions);
const embeddingFunctionNames = filterFunctions(this.alignableEmbeddingFunctions);
const allFunctions = [...symbolicFunctionNames, ...embeddingFunctionNames];

// If more than one function is found, throw an error
if (allFunctions.length > 1) {
throw new Error(`Multiple functions with name "${className}" and docstring "${docstring}" found.`);
}
// If no function is found, throw an error
if (allFunctions.length === 0) {
throw new Error(`Function with name "${className}" and docstring "${docstring}" not found in class "${classContext.name}". Check source file: "${classContext.sourceFile}"`);
}
return allFunctions[0]

// Filter by the members of the classContext
// const memberFunctionNames = allFunctionNames.filter(name =>
// typeof classContext[name] === 'function'
// );
// const symbolicFunctionNames = Object.keys(this.alignableSymbolicFunctions)
// .filter(key => key.startsWith(classPrefix))
// .map(key => key.slice(classPrefix.length)); // Remove the prefix
//
// const embeddingFunctionNames = Object.keys(this.alignableEmbeddingFunctions)
// .filter(key => key.startsWith(classPrefix))
// .map(key => key.slice(classPrefix.length)); // Remove the prefix
//
// return [...symbolicFunctionNames, ...embeddingFunctionNames];
}
static loadFunctionDescription(functionName: string, docString: string): FunctionDescription {
// Iterate over alignableSymbolicFunctions
Expand Down
9 changes: 6 additions & 3 deletions src/tanuki.ts
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,12 @@ export function patch<OutputType, InputType>(config?: PatchConfig) {
}

// Return a function that takes an input of type InputType and returns a value of type OutputType
return async (input: InputType): Promise<OutputType> => {
const functionName: string = getCallerInfo(Register.getNamedFunctions());
const functionDescription: FunctionDescription = Register.loadFunctionDescription(functionName, docstring);
return async function(this: any, input: InputType): Promise<OutputType> {
const parentClass = this;
const functionDescription = Register.getNamedFunctions(parentClass, docstring);
//const functionName: string = getCallerInfo(namedFunctions);
//const functionName: string = foundFunction.name
//const functionDescription: FunctionDescription = Register.loadFunctionDescription(functionName, docstring);
let embeddingCase = false;
if (config) {
FunctionModeler.setConfig(functionDescription, config);
Expand Down
48 changes: 31 additions & 17 deletions src/tanukiTransformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ export interface JSONSchema {
[key: string]: any;
}

class FunctionDescription {
class CompiledFunctionDescription {
name: string;
docstring: string;
parentName?: string;
sourceFile?: string;
inputTypeDefinition?: string
inputTypeSchema?: JSONSchema;
outputTypeDefinition?: string;
Expand All @@ -58,6 +60,8 @@ class FunctionDescription {
constructor(
name: string,
docstring: string,
parentName?: string,
sourceFile?: string,
inputTypeDefinition?: string,
outputTypeDefinition?: string,
inputTypeSchema?: JSONSchema,
Expand All @@ -66,6 +70,12 @@ class FunctionDescription {
) {
this.name = name;
this.docstring = docstring;
if (parentName != undefined) {
this.parentName = parentName;
}
if (sourceFile != undefined) {
this.sourceFile = sourceFile;
}
if (inputTypeDefinition != null) {
this.inputTypeDefinition = inputTypeDefinition;
}
Expand Down Expand Up @@ -124,7 +134,7 @@ export class PatchFunctionCompiler {
return
}
console.debug("Compiling " + file.fileName)
const patchFunctions: FunctionDescription[] = [];
const patchFunctions: CompiledFunctionDescription[] = [];

// First, populate type definitions
this.ts.forEachChild(file, node => this.extractTypeDefinitions(node));
Expand Down Expand Up @@ -154,10 +164,10 @@ export class PatchFunctionCompiler {
pf.name,
inputTypeDefinitionTokenStream
);
const name = pf.name
/*const name = pf.name
if (this.compiledFunctionNames.indexOf(name) > -1) {
throw new Error("Function name collision in `"+file.fileName+".\nPlease move `"+name+"` into its own namespace.")
}
throw new Error("Function name collision in `"+file.fileName+".\nPlease move `"+name+"` into its own namespace. Tanuki functions have to be unique across all files. ")
}*/

this.compiledFunctionNames.push(pf.name)
});
Expand Down Expand Up @@ -290,7 +300,7 @@ export class PatchFunctionCompiler {
}
this.ts.forEachChild(node, child => this.extractTypeDefinitions(child));
}
visit(node: ts.Node, patchFunctions: FunctionDescription[], file: ts.SourceFile): void {
visit(node: ts.Node, patchFunctions: CompiledFunctionDescription[], file: ts.SourceFile): void {
if (this.ts.isClassDeclaration(node) || this.ts.isModuleDeclaration(node)) {
const previousClassOrModule = this.currentClassOrModule;
this.currentClassOrModule = node;
Expand Down Expand Up @@ -337,7 +347,7 @@ export class PatchFunctionCompiler {
functionName: string,
file: ts.SourceFile,
currentClassOrModule: ts.ClassDeclaration | ts.ModuleDeclaration | null
): FunctionDescription | null {
): CompiledFunctionDescription | null {
if (
this.ts.isPropertyDeclaration(node) &&
node.initializer &&
Expand All @@ -347,7 +357,7 @@ export class PatchFunctionCompiler {
let parent = '';
if (node.parent && node.parent.name && this.ts.isClassDeclaration(node.parent)) {
if (node.parent.name.getText() === 'Function') {
throw new Error("Tanuki functions cannot live in an class called `Function`");
throw new Error("The class `Function` cannot have patched functions as members, as this is a reserved word. You could rename the class.")
}
// @ts-ignore
parent = node.parent.name.getText() + ".";
Expand All @@ -372,7 +382,7 @@ export class PatchFunctionCompiler {
}
const staticFlag = isNodeStatic(node);

let name = staticFlag ? functionName : parent+functionName; // Use the passed function name
let name = functionName; // Use the passed function name

const docstringWithTicks = node.initializer.template.getText();
const docstring = docstringWithTicks.replace(/`/g, '');
Expand All @@ -384,20 +394,24 @@ export class PatchFunctionCompiler {
const outputTypeNode: ts.Node = typeArguments[0];
const inputTypeNode: ts.Node = typeArguments[1];

const inputType = inputTypeNode.getText(); // Get the textual representation of the input type
const outputType = outputTypeNode.getText(); // Get the textual representation of the output type
const inputType = inputTypeNode.getText(file); // Get the textual representation of the input type
const outputType = outputTypeNode.getText(file); // Get the textual representation of the output type

const current = this.currentClassOrModule
const inputTypeDefinition = this.extractTypeDefinition(inputType, current);
const outputTypeDefinition = this.extractTypeDefinition(outputType, current);
const inputTypeDefinition = this.extractTypeDefinition(inputType, current, file);
const outputTypeDefinition = this.extractTypeDefinition(outputType, current, file);

const type = !outputTypeDefinition.startsWith('Embedding')
? FunctionType.SYMBOLIC
: FunctionType.EMBEDDABLE;
//name = current?.name?.getText() + "." + name
return new FunctionDescription(
// @ts-ignore
let parentName = current.name.getText()
return new CompiledFunctionDescription(
name,
docstring,
parentName,
file.fileName,
inputTypeDefinition,
outputTypeDefinition,
undefined,
Expand All @@ -415,7 +429,7 @@ export class PatchFunctionCompiler {

return null;
}
extractTypeDefinition(type: string, currentScope: ts.Node | null): string {
extractTypeDefinition(type: string, currentScope: ts.Node | null, file: SourceFile): string {

// If the type is a primitive type, return it
const primitiveTypes = ['number', 'string', 'boolean', 'null'];
Expand All @@ -438,7 +452,7 @@ export class PatchFunctionCompiler {
return definition;
}
if (currentScope) {
definition = this.findAndResolveType(type, currentScope.getSourceFile())
definition = this.findAndResolveType(type, file)
if (definition) {
return definition;
}
Expand Down Expand Up @@ -943,7 +957,7 @@ export class PatchFunctionCompiler {
return hasPatch;
}

writeToJSON(patchFunctions: FunctionDescription[]): void {
writeToJSON(patchFunctions: CompiledFunctionDescription[]): void {
// Determine the output directory
const distDirectory = PatchFunctionCompiler.getDistDirectory();

Expand Down

0 comments on commit f47b23f

Please sign in to comment.