Skip to content

Commit

Permalink
Use tf.function for list column operations (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv committed Dec 27, 2022
1 parent 42dd301 commit fd5d3fc
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions tests/unit/dataloader/test_tf_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,25 +94,36 @@ def test_nested_list():
schema = ds.schema
schema["label"] = schema["label"].with_tags([Tags.TARGET])
ds.schema = schema
train_dataset = tf_dataloader.Loader(
loader = tf_dataloader.Loader(
ds,
batch_size=batch_size,
shuffle=False,
)

batch = next(train_dataset)
batch = next(loader)

# [[1,2,3],[3,1],[...],[]]
nested_data_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data"][0][:, 0], tf.cast(batch[0]["data"][1][:, 0], tf.int32)
).to_tensor()
@tf.function
def _ragged_for_nested_data_col():
nested_data_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data"][0][:, 0], tf.cast(batch[0]["data"][1][:, 0], tf.int32)
).to_tensor()
return nested_data_col

nested_data_col = _ragged_for_nested_data_col()
true_data_col = tf.reshape(
tf.ragged.constant(df.iloc[:batch_size, 0].tolist()).to_tensor(),
[batch_size, -1],
tf.ragged.constant(df.iloc[:batch_size, 0].tolist()).to_tensor(), [batch_size, -1]
)

# [1,2,3]
multihot_data2_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data2"][0][:, 0], tf.cast(batch[0]["data2"][1][:, 0], tf.int32)
).to_tensor()
@tf.function
def _ragged_for_multihot_data_col():
multihot_data2_col = tf.RaggedTensor.from_row_lengths(
batch[0]["data2"][0][:, 0], tf.cast(batch[0]["data2"][1][:, 0], tf.int32)
).to_tensor()
return multihot_data2_col

multihot_data2_col = _ragged_for_multihot_data_col()
true_data2_col = tf.reshape(
tf.ragged.constant(df.iloc[:batch_size, 1].tolist()).to_tensor(),
[batch_size, -1],
Expand Down

0 comments on commit fd5d3fc

Please sign in to comment.