[SPARK-41794][SQL] Add try_remainder function and re-enable column tests #46434

Closed
@@ -1932,6 +1932,14 @@ object functions {
*/
def try_divide(left: Column, right: Column): Column = Column.fn("try_divide", left, right)

/**
* Returns the remainder of `dividend``/``divisor`. Its result is always null if `divisor` is 0.
*
* @group math_funcs
* @since 4.0.0
*/
def try_remainder(left: Column, right: Column): Column = Column.fn("try_remainder", left, right)

/**
* Returns `left``*``right` and the result is null on overflow. The acceptable input types are
* the same with the `*` operator.
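For orientation, a minimal sketch of how the new Scala API surfaces (assumes a SparkSession `spark` and its implicits; it mirrors the Python docstring example later in this PR and is not part of the diff itself):

import spark.implicits._
import org.apache.spark.sql.functions.{col, try_remainder}

val df = Seq((6000, 15), (3, 2), (1234, 0)).toDF("a", "b")
// A zero divisor yields NULL instead of a runtime error, even under ANSI mode.
df.select(try_remainder(col("a"), col("b"))).show()
// +-------------------+
// |try_remainder(a, b)|
// +-------------------+
// |                  0|
// |                  1|
// |               NULL|
// +-------------------+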
@@ -121,8 +121,10 @@ class StreamingQueryListenerBus(sparkSession: SparkSession) extends Logging {
}
} catch {
case e: Exception =>
logWarning("StreamingQueryListenerBus Handler thread received exception, all client" +
" side listeners are removed and handler thread is terminated.", e)
logWarning(
"StreamingQueryListenerBus Handler thread received exception, all client" +
" side listeners are removed and handler thread is terminated.",
e)
Member commented on lines +124 to +127:

Suggested change:
-          logWarning(
-            "StreamingQueryListenerBus Handler thread received exception, all client" +
-              " side listeners are removed and handler thread is terminated.",
-            e)
+          logWarning("StreamingQueryListenerBus Handler thread received exception, all client" +
+            " side listeners are removed and handler thread is terminated.", e)

Contributor (author) replied:

I did the Spark Connect auto format and it produced these changes. I'm OK with reverting them in the worst case, but at the same time your suggestion introduces a manual style adjustment.

LMK

lock.synchronized {
executionThread = Option.empty
listeners.forEach(remove(_))
@@ -451,8 +451,7 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.streaming.RemoteStreamingQuery$"),
// Skip client side listener specific class
    ProblemFilters.exclude[MissingClassProblem](
-      "org.apache.spark.sql.streaming.StreamingQueryListenerBus"
-    ),
+      "org.apache.spark.sql.streaming.StreamingQueryListenerBus"),
Member commented:

Suggested change:
-      "org.apache.spark.sql.streaming.StreamingQueryListenerBus"),
+      "org.apache.spark.sql.streaming.StreamingQueryListenerBus"
+    ),


// Encoders are in the wrong JAR
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"),
7 changes: 7 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
@@ -934,6 +934,13 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column:
try_divide.__doc__ = pysparkfuncs.try_divide.__doc__


def try_remainder(left: "ColumnOrName", right: "ColumnOrName") -> Column:
return _invoke_function_over_columns("try_remainder", left, right)


try_remainder.__doc__ = pysparkfuncs.try_remainder.__doc__


def try_multiply(left: "ColumnOrName", right: "ColumnOrName") -> Column:
return _invoke_function_over_columns("try_multiply", left, right)

52 changes: 51 additions & 1 deletion python/pyspark/sql/functions/builtin.py
@@ -638,7 +638,7 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column:
| 4 months|
+--------------------------------------------------+

-    Example 3: Exception druing division, resulting in NULL when ANSI mode is on
+    Example 3: Exception during division, resulting in NULL when ANSI mode is on

>>> import pyspark.sql.functions as sf
>>> origin = spark.conf.get("spark.sql.ansi.enabled")
@@ -657,6 +657,56 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column:
return _invoke_function_over_columns("try_divide", left, right)


@_try_remote_functions
def try_remainder(left: "ColumnOrName", right: "ColumnOrName") -> Column:
"""
Returns the remainder after `dividend`/`divisor`. Its result is
always null if `divisor` is 0.

.. versionadded:: 4.0.0

Parameters
----------
left : :class:`~pyspark.sql.Column` or str
dividend
right : :class:`~pyspark.sql.Column` or str
divisor

Examples
--------
Example 1: Integer divided by Integer.

>>> import pyspark.sql.functions as sf
>>> spark.createDataFrame(
... [(6000, 15), (3, 2), (1234, 0)], ["a", "b"]
... ).select(sf.try_remainder("a", "b")).show()
+-------------------+
|try_remainder(a, b)|
+-------------------+
| 0|
| 1|
| NULL|
+-------------------+

Example 2: Exception during division, resulting in NULL when ANSI mode is on

>>> import pyspark.sql.functions as sf
>>> origin = spark.conf.get("spark.sql.ansi.enabled")
>>> spark.conf.set("spark.sql.ansi.enabled", "true")
>>> try:
... df = spark.range(1)
... df.select(sf.try_remainder(df.id, sf.lit(0))).show()
... finally:
... spark.conf.set("spark.sql.ansi.enabled", origin)
+--------------------+
|try_remainder(id, 0)|
+--------------------+
| NULL|
+--------------------+
"""
return _invoke_function_over_columns("try_remainder", left, right)


@_try_remote_functions
def try_multiply(left: "ColumnOrName", right: "ColumnOrName") -> Column:
"""
16 changes: 8 additions & 8 deletions python/pyspark/sql/tests/connect/test_connect_column.py
@@ -772,8 +772,8 @@ def test_column_accessor(self):
sdf.select(sdf.z[0], sdf.z[1], sdf["z"][2]).toPandas(),
)
self.assert_eq(
cdf.select(CF.col("z")[0], cdf.z[10], CF.col("z")[-10]).toPandas(),
sdf.select(SF.col("z")[0], sdf.z[10], SF.col("z")[-10]).toPandas(),
cdf.select(CF.col("z")[0], CF.get(cdf.z, 10), CF.get(CF.col("z"), -10)).toPandas(),
sdf.select(SF.col("z")[0], SF.get(sdf.z, 10), SF.get(SF.col("z"), -10)).toPandas(),
)
self.assert_eq(
cdf.select(cdf.z.getItem(0), cdf.z.getItem(1), cdf["z"].getField(2)).toPandas(),
@@ -824,8 +824,12 @@ def test_column_arithmetic_ops(self):
)

self.assert_eq(
-            cdf.select(cdf.a % cdf["b"], cdf["a"] % 2, 12 % cdf.c).toPandas(),
-            sdf.select(sdf.a % sdf["b"], sdf["a"] % 2, 12 % sdf.c).toPandas(),
+            cdf.select(
+                cdf.a % cdf["b"], cdf["a"] % 2, CF.try_remainder(CF.lit(12), cdf.c)
+            ).toPandas(),
+            sdf.select(
+                sdf.a % sdf["b"], sdf["a"] % 2, SF.try_remainder(SF.lit(12), sdf.c)
+            ).toPandas(),
)

self.assert_eq(
@@ -1022,13 +1026,9 @@ def test_distributed_sequence_id(self):


if __name__ == "__main__":
-    import os
import unittest
from pyspark.sql.tests.connect.test_connect_column import * # noqa: F401

-    # TODO(SPARK-41794): Enable ANSI mode in this file.
-    os.environ["SPARK_ANSI_SQL_MODE"] = "false"

try:
import xmlrunner

@@ -451,6 +451,7 @@ object FunctionRegistry {
// "try_*" function which always return Null instead of runtime error.
expression[TryAdd]("try_add"),
expression[TryDivide]("try_divide"),
expression[TryRemainder]("try_remainder"),
expression[TrySubtract]("try_subtract"),
expression[TryMultiply]("try_multiply"),
expression[TryElementAt]("try_element_at"),
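Registration under the "try_*" group is what makes the function resolvable by name from SQL; a quick smoke check along the lines of the ExpressionDescription examples below (assumes a SparkSession `spark`; illustrative only):

spark.sql("SELECT try_remainder(3, 2), try_remainder(1, 0)").show()
// +-------------------+-------------------+
// |try_remainder(3, 2)|try_remainder(1, 0)|
// +-------------------+-------------------+
// |                  1|               NULL|
// +-------------------+-------------------+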
@@ -132,6 +132,43 @@ case class TryDivide(left: Expression, right: Expression, replacement: Expression)
}
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(dividend, divisor) - Returns the remainder after `expr1`/`expr2`. " +
"`dividend` must be a numeric. `divisor` must be a numeric.",
examples = """
Examples:
> SELECT _FUNC_(3, 2);
1
> SELECT _FUNC_(2L, 2L);
0
> SELECT _FUNC_(3.0, 2.0);
1.0
> SELECT _FUNC_(1, 0);
NULL
""",
since = "3.2.0",
grundprinzip marked this conversation as resolved.
Show resolved Hide resolved
group = "math_funcs")
// scalastyle:on line.size.limit
case class TryRemainder(left: Expression, right: Expression, replacement: Expression)
extends RuntimeReplaceable with InheritAnalysisRules {
def this(left: Expression, right: Expression) = this(left, right,
(left.dataType, right.dataType) match {
case (_: NumericType, _: NumericType) => Remainder(left, right, EvalMode.TRY)
// TODO: support TRY eval mode on datetime arithmetic expressions.
case _ => TryEval(Remainder(left, right, EvalMode.ANSI))
}
)

override def prettyName: String = "try_remainder"

override def parameters: Seq[Expression] = Seq(left, right)

override protected def withNewChildInternal(newChild: Expression): Expression = {
copy(replacement = newChild)
}
}

@ExpressionDescription(
usage = "_FUNC_(expr1, expr2) - Returns `expr1`-`expr2` and the result is null on overflow. " +
"The acceptable input types are the same with the `-` operator.",
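A sketch of how TryRemainder's two-argument constructor picks its replacement based on operand types (hypothetical REPL session; `Remainder`, `TryEval`, and `EvalMode` as in this diff):

import org.apache.spark.sql.catalyst.expressions.{EvalMode, Literal, Remainder, TryEval}

// Both operands numeric: Remainder runs natively in TRY mode and returns
// null on a zero divisor without ever raising an error.
val numeric = new TryRemainder(Literal(3), Literal(2))
// numeric.replacement == Remainder(Literal(3), Literal(2), EvalMode.TRY)

// Any non-numeric operand: fall back to an ANSI Remainder wrapped in
// TryEval, which catches the runtime error and converts it to null.
val fallback = new TryRemainder(Literal("3"), Literal(2))
// fallback.replacement == TryEval(Remainder(Literal("3"), Literal(2), EvalMode.ANSI))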
@@ -906,6 +906,10 @@ case class Remainder(

override def inputType: AbstractDataType = NumericType

  // `try_remainder` has exactly the same behavior as the legacy remainder, so here it only
  // executes the error code path when `evalMode` is `ANSI`.
protected override def failOnError: Boolean = evalMode == EvalMode.ANSI

override def symbol: String = "%"
override def decimalMethod: String = "remainder"

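The override means the three eval modes of `%` differ only as follows (a sketch, assuming a SparkSession `spark`; not part of the diff):

spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT 5 % 0").show()               // NULL: legacy mode never errors
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT try_remainder(5, 0)").show() // NULL: TRY skips the error path
// spark.sql("SELECT 5 % 0").show()            // ANSI: would throw DIVIDE_BY_ZERO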
@@ -46,6 +46,19 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("try_remainder") {
Seq(
(3.0, 2.0, 1.0),
(1.0, 0.0, null),
(-1.0, 0.0, null)
).foreach { case (a, b, expected) =>
val left = Literal(a)
val right = Literal(b)
val input = Remainder(left, right, EvalMode.TRY)
checkEvaluation(input, expected)
}
}

test("try_subtract") {
Seq(
(1, 1, 0),
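The new test exercises doubles only; a hypothetical companion case for integral inputs would follow the same shape (sketch, not part of the diff):

  test("try_remainder - integral inputs") {
    Seq(
      (3, 2, 1),
      (1, 0, null),
      (-1, 0, null)
    ).foreach { case (a, b, expected) =>
      // TRY mode returns null for a zero divisor instead of raising an error.
      checkEvaluation(Remainder(Literal(a), Literal(b), EvalMode.TRY), expected)
    }
  }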
9 changes: 9 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1937,6 +1937,15 @@ object functions {
*/
def try_divide(left: Column, right: Column): Column = Column.fn("try_divide", left, right)

/**
* Returns the remainder of `dividend``/``divisor`. Its result is
* always null if `divisor` is 0.
*
* @group math_funcs
* @since 4.0.0
*/
def try_remainder(left: Column, right: Column): Column = Column.fn("try_remainder", left, right)

/**
* Returns `left``*``right` and the result is null on overflow. The acceptable input types are
* the same with the `*` operator.
@@ -349,6 +349,7 @@
| org.apache.spark.sql.catalyst.expressions.TryElementAt | try_element_at | SELECT try_element_at(array(1, 2, 3), 2) | struct<try_element_at(array(1, 2, 3), 2):int> |
| org.apache.spark.sql.catalyst.expressions.TryMultiply | try_multiply | SELECT try_multiply(2, 3) | struct<try_multiply(2, 3):int> |
| org.apache.spark.sql.catalyst.expressions.TryReflect | try_reflect | SELECT try_reflect('java.util.UUID', 'randomUUID') | struct<try_reflect(java.util.UUID, randomUUID):string> |
| org.apache.spark.sql.catalyst.expressions.TryRemainder | try_remainder | SELECT try_remainder(3, 2) | struct<try_remainder(3, 2):int> |
| org.apache.spark.sql.catalyst.expressions.TrySubtract | try_subtract | SELECT try_subtract(2, 1) | struct<try_subtract(2, 1):int> |
| org.apache.spark.sql.catalyst.expressions.TryToBinary | try_to_binary | SELECT try_to_binary('abc', 'utf-8') | struct<try_to_binary(abc, utf-8):binary> |
| org.apache.spark.sql.catalyst.expressions.TryToNumber | try_to_number | SELECT try_to_number('454', '999') | struct<try_to_number(454, 999):decimal(3,0)> |
@@ -707,6 +707,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
df1.select(try_divide(make_interval(col("year"), col("month")), lit(0))))
}

test("try_remainder") {
val df = Seq((10, 3), (5, 5), (5, 0)).toDF("birth", "age")
checkAnswer(df.selectExpr("try_remainder(birth, age)"), Seq(Row(1), Row(0), Row(null)))
}

test("try_element_at") {
val df = Seq((Array(1, 2, 3), 2)).toDF("a", "b")
checkAnswer(df.selectExpr("try_element_at(a, b)"), Seq(Row(2)))
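A hypothetical end-to-end check that the DataFrame API stays null-returning under ANSI mode (assumes `withSQLConf` from SQLTestUtils and `org.apache.spark.sql.internal.SQLConf`, both available to this suite; sketch only):

  test("try_remainder under ANSI mode") {
    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
      val df = Seq((5, 0)).toDF("a", "b")
      // A plain `%` would throw DIVIDE_BY_ZERO here; try_remainder returns null.
      checkAnswer(df.selectExpr("try_remainder(a, b)"), Seq(Row(null)))
    }
  }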