🚀[FEA]: Torch Jit for SympyToTorch #153

Open
bridgesign opened this issue Jun 6, 2024 · 0 comments
Labels
? - Needs Triage (Need team to review and classify), enhancement (New feature or request)

bridgesign commented Jun 6, 2024

Is this a new feature, an improvement, or a change to existing functionality?

Improvement

How would you describe the priority of this feature request

Medium

Please provide a clear description of problem you would like to solve.

I was trying to compile some SymPy expressions with torch jit and ran into issues in the file that implements the torch printer.

While reading through how SymPy creates lambda functions, I hacked together the following workaround, which makes the lambdified function compatible with torch jit.

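import linecache
from typing import Callable, Dict, List, Optional

import sympy
import torch

# torch_lambdify is the existing helper from the torch printer file.


def sympy_torch_script(
    expr: sympy.Expr,
    keys: List[str],
    extra_funcs: Optional[Dict] = None,
) -> Callable:
    # Lambdify the expression into a function of positional tensor arguments.
    torch_expr = torch_lambdify(expr, keys, extra_funcs=extra_funcs)
    torch_expr.__module__ = "torch"
    filename = '<wrapped>-%s' % torch_expr.__code__.co_filename
    funclocals = {}
    namespace = {"Dict": Dict, "torch": torch, "func": torch_expr}
    funcname = "_wrapped"
    # Generate a wrapper that unpacks the required keys from a dict of tensors.
    code = 'def %s(vars: Dict[str, torch.Tensor]) -> torch.Tensor:\n' % funcname
    code += '    return func('
    for key in keys:
        code += 'vars["%s"],' % key
    code += ')\n'
    c = compile(code, filename, 'exec')
    exec(c, namespace, funclocals)
    # Register the generated source so that inspect can retrieve it later.
    # mtime=None stops linecache.checkcache() from evicting the entry.
    linecache.cache[filename] = (
        len(code),
        None,
        code.splitlines(keepends=True),
        filename,
    )
    func = funclocals[funcname]
    func.__module__ = "torch"
    return func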

The code above creates a function that takes a dictionary of tensors, picks out the relevant arguments, and then calls the lambdified function with them. I had to make some minor changes to torch_lambdify, but it works on both CPU and CUDA. There is still an issue when the expression is a constant: jit then returns it as an int or float rather than a tensor.
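To make the wrapper-generation step easier to inspect in isolation, here is a minimal sketch of the same pattern using only the standard library. It is an illustration under assumptions, not the code from the torch printer: `make_dict_wrapper` and `add` are hypothetical names, `add` is a plain-Python stand-in for the lambdified function, and the generated wrapper omits the `Dict[str, torch.Tensor]` annotations so that it runs without torch. The point it shows is that once the generated source is registered in `linecache`, `inspect.getsource` can recover it, which matters because torch jit scripting compiles from the function's source text.

```python
import inspect
import linecache
from typing import Callable, List


def make_dict_wrapper(func: Callable, keys: List[str], funcname: str = "_wrapped") -> Callable:
    """Generate funcname(vars) that calls func(vars[k0], vars[k1], ...)."""
    # %r quotes each key safely, even if the key itself contains quote characters.
    args = ", ".join("vars[%r]" % key for key in keys)
    code = "def %s(vars):\n    return func(%s)\n" % (funcname, args)

    # Hypothetical pseudo-filename; it only needs to be a unique linecache key.
    filename = "<wrapped>-%d" % id(func)
    namespace = {"func": func}
    exec(compile(code, filename, "exec"), namespace)

    # mtime=None makes linecache.checkcache() keep the entry even though
    # no file with this name exists on disk.
    linecache.cache[filename] = (
        len(code),
        None,
        code.splitlines(keepends=True),
        filename,
    )
    return namespace[funcname]


def add(x, y):
    # Plain-Python stand-in for the lambdified expression (assumption).
    return x + y


wrapped = make_dict_wrapper(add, ["x", "y"])
# Extra keys in the dictionary are simply ignored by the wrapper.
result = wrapped({"x": 1.0, "y": 2.0, "unused": 5.0})
# The generated source is recoverable thanks to the linecache entry.
source = inspect.getsource(wrapped)
```

Using `%r` for the keys is a small design choice that differs from the snippet above: it keeps the generated source valid even for keys containing double quotes.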

Seems like a good idea?

Describe any alternatives you have considered

No response

Additional context

No response

bridgesign added the "? - Needs Triage" and "enhancement" labels Jun 6, 2024
bridgesign changed the title from "Torch Jit for SympyToTorch🚀[FEA]:" to "🚀[FEA]: Torch Jit for SympyToTorch" Jun 6, 2024