Skip to content

Commit

Permalink
[SPARK-44410][PYTHON][CONNECT] Set active session in create, not just…
Browse files Browse the repository at this point in the history
… getOrCreate

### What changes were proposed in this pull request?

ML and other uses rely on _active_spark_session to find spark session.

Sessions created using getOrCreate method set this variable, but sessions created with create don't.

Update create method to set _active_spark_session.

### Why are the changes needed?
This breaks spark connect customers, such as pyspark.ml and pandas from finding created session if it was created with create.

### Does this PR introduce _any_ user-facing change?
Sessions created by create are set as current session. This is slightly different behavior then before, however this
suits interest of almost all clients. The only case it might break is if someone uses mix of both `create` and `getOrCreate` relying on this exact semantic.

We can hide it under configuration flag, e.g. `create(set_active_session=False)` if undesired. In this case clients who use `create` and want to use pyspark.ml/pandas will need to update to set it to True.

### How was this patch tested?
UT

Closes apache#41987 from cdkrot/spark_session_create_store_session.

Authored-by: Alice Sayutina <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
cdkrot authored and HyukjinKwon committed Jul 16, 2023
1 parent d91904e commit ce53fdf
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
11 changes: 9 additions & 2 deletions python/pyspark/sql/connect/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@

# `_active_spark_session` stores the active spark connect session created by
# `SparkSession.builder.getOrCreate`. It is used by ML code.
# If sessions are created with `SparkSession.builder.create`, it stores
# The last created session
_active_spark_session = None


Expand Down Expand Up @@ -172,6 +174,8 @@ def enableHiveSupport(self) -> "SparkSession.Builder":
)

def create(self) -> "SparkSession":
global _active_spark_session

has_channel_builder = self._channel_builder is not None
has_spark_remote = "spark.remote" in self._options

Expand All @@ -188,11 +192,14 @@ def create(self) -> "SparkSession":

if has_channel_builder:
assert self._channel_builder is not None
return SparkSession(connection=self._channel_builder)
session = SparkSession(connection=self._channel_builder)
else:
spark_remote = to_str(self._options.get("spark.remote"))
assert spark_remote is not None
return SparkSession(connection=spark_remote)
session = SparkSession(connection=spark_remote)

_active_spark_session = session
return session

def getOrCreate(self) -> "SparkSession":
global _active_spark_session
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3323,9 +3323,9 @@ def test_can_create_multiple_sessions_to_different_remotes(self):
other = PySparkSession.builder.remote("sc://other.remote:114/").create()
self.assertNotEquals(self.spark, other)

# Reuses an active session that was previously created.
# Gets currently active session.
same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate()
self.assertEquals(self.spark, same)
self.assertEquals(other, same)
same.stop()

# Make sure the environment is clean.
Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/sql/tests/connect/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,10 @@ def test_session_stop(self):
self.assertFalse(session.is_stopped)
session.stop()
self.assertTrue(session.is_stopped)

def test_session_create_sets_active_session(self):
session = RemoteSparkSession.builder.remote("sc://abc").create()
session2 = RemoteSparkSession.builder.remote("sc://other").getOrCreate()

self.assertIs(session, session2)
session.stop()

0 comments on commit ce53fdf

Please sign in to comment.