Skip to content

Commit

Permalink
Merge pull request #255 from bfredl/threading
Browse files Browse the repository at this point in the history
catch requests from invalid thread
  • Loading branch information
bfredl authored Nov 2, 2018
2 parents 0c5257a + d2bf46f commit 0852da8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
20 changes: 20 additions & 0 deletions neovim/api/nvim.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Main Nvim interface."""
import os
import sys
import threading
from functools import partial
from traceback import format_stack

Expand Down Expand Up @@ -165,6 +166,15 @@ def request(self, name, *args, **kwargs):
present and True, a asynchronous notification is sent instead. This
will never block, and the return value or error is ignored.
"""
if (self._session._loop_thread is not None and
threading.current_thread() != self._session._loop_thread):

msg = ("request from non-main thread:\n{}\n"
.format('\n'.join(format_stack(None, 5)[:-1])))

self.async_call(self._err_cb, msg)
raise NvimError("request from non-main thread")

decode = kwargs.pop('decode', self._decode)
args = walk(self._to_nvim, args)
res = self._session.request(name, *args, **kwargs)
Expand Down Expand Up @@ -382,8 +392,18 @@ def out_write(self, msg, **kwargs):

def err_write(self, msg, **kwargs):
"""Print `msg` as an error message."""
if self._thread_invalid():
# special case: if a non-main thread writes to stderr
# i.e. due to an uncaught exception, pass it through
# without raising an additional exception.
self.async_call(self.err_write, msg, **kwargs)
return
return self.request('nvim_err_write', msg, **kwargs)

def _thread_invalid(self):
return (self._session._loop_thread is not None and
threading.current_thread() != self._session._loop_thread)

def quit(self, quit_command='qa!'):
"""Send a quit command to Nvim.
Expand Down
4 changes: 4 additions & 0 deletions neovim/msgpack_rpc/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Synchronous msgpack-rpc session layer."""
import logging
import threading
from collections import deque
from traceback import format_exc

Expand Down Expand Up @@ -29,6 +30,7 @@ def __init__(self, async_session):
self._is_running = False
self._setup_exception = None
self.loop = async_session.loop
self._loop_thread = None

def threadsafe_call(self, fn, *args, **kwargs):
"""Wrapper around `AsyncSession.threadsafe_call`."""
Expand Down Expand Up @@ -110,6 +112,7 @@ def run(self, request_cb, notification_cb, setup_cb=None):
self._notification_cb = notification_cb
self._is_running = True
self._setup_exception = None
self._loop_thread = threading.current_thread()

def on_setup():
try:
Expand All @@ -135,6 +138,7 @@ def on_setup():
self._is_running = False
self._request_cb = None
self._notification_cb = None
self._loop_thread = None

if self._setup_exception:
raise self._setup_exception
Expand Down
2 changes: 2 additions & 0 deletions neovim/plugin/script_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ def eval(self, expr):
# This was copied/adapted from nvim-python help
def path_hook(nvim):
def _get_paths():
if nvim._thread_invalid():
return []
return discover_runtime_directories(nvim)

def _find_module(fullname, oldtail, path):
Expand Down

0 comments on commit 0852da8

Please sign in to comment.