Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing a custom cache_dir to tf.keras.datasets.load_data. This is helpful when the default location ~/.keras in home directory has limited disk space. #817

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/auto-assignment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
welcome:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v3
- uses: actions/github-script@v6
with:
script: |
Expand Down
9 changes: 3 additions & 6 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,15 @@ jobs:

runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: '3.10'
- uses: actions/checkout@v3

- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
uses: actions/cache@v3
uses: actions/cache@v2
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
Expand Down
9 changes: 3 additions & 6 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,15 @@ jobs:
name: Check the code format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: '3.10'
- uses: actions/checkout@v3

- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
uses: actions/cache@v3
uses: actions/cache@v2
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.boston_housing"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], "
argspec: "args=[\'path\', \'test_split\', \'seed\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar10"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar100"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'label_mode\'], varargs=None, keywords=None, defaults=[\'fine\'], "
argspec: "args=[\'label_mode\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'fine\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.fashion_mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'mnist.npz\'], "
argspec: "args=[\'path\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'mnist.npz\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.boston_housing"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], "
argspec: "args=[\'path\', \'test_split\', \'seed\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar10"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar100"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'label_mode\'], varargs=None, keywords=None, defaults=[\'fine\'], "
argspec: "args=[\'label_mode\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'fine\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.fashion_mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'mnist.npz\'], "
argspec: "args=[\'path\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'mnist.npz\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
19 changes: 14 additions & 5 deletions tf_keras/datasets/boston_housing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""Boston housing price regression dataset."""

import os

import numpy as np

from tf_keras.utils.data_utils import get_file
Expand All @@ -23,7 +25,9 @@


@keras_export("keras.datasets.boston_housing.load_data")
def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
def load_data(
path="boston_housing.npz", test_split=0.2, seed=113, cache_dir=None
):
"""Loads the Boston Housing dataset.

This is a dataset taken from the StatLib library which is maintained at
Expand All @@ -43,11 +47,12 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
[StatLib website](http://lib.stat.cmu.edu/datasets/boston).

Args:
path: path where to cache the dataset locally
(relative to `~/.keras/datasets`).
path: path where to cache the dataset locally (relative to
`~/.keras/datasets`).
test_split: fraction of the data to reserve as test set.
seed: Random seed for shuffling the data
before computing the test split.
seed: Random seed for shuffling the data before computing the test split.
cache_dir: directory where to cache the dataset locally. When None,
defaults to `~/.keras/datasets`.

Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand All @@ -64,12 +69,16 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
origin_folder = (
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
)
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
path,
origin=origin_folder + "boston_housing.npz",
file_hash=( # noqa: E501
"f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
),
cache_dir=cache_dir,
)
with np.load(path, allow_pickle=True) as f:
x = f["x"]
Expand Down
10 changes: 9 additions & 1 deletion tf_keras/datasets/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@keras_export("keras.datasets.cifar10.load_data")
def load_data():
def load_data(cache_dir=None):
"""Loads the CIFAR10 dataset.

This is a dataset of 50,000 32x32 color training images and 10,000 test
Expand All @@ -49,6 +49,10 @@ def load_data():
| 8 | ship |
| 9 | truck |

Args:
cache_dir: directory where to cache the dataset locally. When None,
defaults to `~/.keras/datasets`.

Returns:
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.

Expand Down Expand Up @@ -78,13 +82,17 @@ def load_data():
"""
dirname = "cifar-10-batches-py"
origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
dirname,
origin=origin,
untar=True,
file_hash=( # noqa: E501
"6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
),
cache_dir=cache_dir,
)

num_train_samples = 50000
Expand Down
8 changes: 7 additions & 1 deletion tf_keras/datasets/cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@keras_export("keras.datasets.cifar100.load_data")
def load_data(label_mode="fine"):
def load_data(label_mode="fine", cache_dir=None):
"""Loads the CIFAR100 dataset.

This is a dataset of 50,000 32x32 color training images and
Expand All @@ -39,6 +39,8 @@ def load_data(label_mode="fine"):
label_mode: one of "fine", "coarse". If it is "fine" the category labels
are the fine-grained labels, if it is "coarse" the output labels are the
coarse-grained superclasses.
cache_dir: directory where to cache the dataset locally. When None,
defaults to `~/.keras/datasets`.

Returns:
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand Down Expand Up @@ -75,13 +77,17 @@ def load_data(label_mode="fine"):

dirname = "cifar-100-python"
origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
dirname,
origin=origin,
untar=True,
file_hash=( # noqa: E501
"85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
),
cache_dir=cache_dir,
)

fpath = os.path.join(path, "train")
Expand Down
20 changes: 16 additions & 4 deletions tf_keras/datasets/fashion_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


@keras_export("keras.datasets.fashion_mnist.load_data")
def load_data():
def load_data(cache_dir=None):
"""Loads the Fashion-MNIST dataset.

This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,
Expand All @@ -48,6 +48,10 @@ def load_data():
| 8 | Bag |
| 9 | Ankle boot |

Args:
cache_dir: directory where to cache the dataset locally. When None,
defaults to `~/.keras/datasets`.

Returns:
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.

Expand Down Expand Up @@ -77,7 +81,6 @@ def load_data():
The copyright for Fashion-MNIST is held by Zalando SE.
Fashion-MNIST is licensed under the [MIT license](
https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).

"""
dirname = os.path.join("datasets", "fashion-mnist")
base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
Expand All @@ -87,10 +90,19 @@ def load_data():
"t10k-labels-idx1-ubyte.gz",
"t10k-images-idx3-ubyte.gz",
]

if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
paths = []
for fname in files:
paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))
paths.append(
get_file(
fname,
origin=base + fname,
cache_dir=cache_dir,
cache_subdir=dirname,
)
)

with gzip.open(paths[0], "rb") as lbpath:
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
Expand Down
8 changes: 8 additions & 0 deletions tf_keras/datasets/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""IMDB sentiment classification dataset."""

import json
import os

import numpy as np

Expand All @@ -36,6 +37,7 @@ def load_data(
start_char=1,
oov_char=2,
index_from=3,
cache_dir=None,
**kwargs,
):
"""Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
Expand Down Expand Up @@ -73,6 +75,8 @@ def load_data(
Words that were cut out because of the `num_words` or
`skip_top` limits will be replaced with this character.
index_from: int. Index actual words with this index and higher.
cache_dir: directory where to cache the dataset locally. When None,
defaults to `~/.keras/datasets`.
**kwargs: Used for backwards compatibility.

Returns:
Expand Down Expand Up @@ -108,12 +112,16 @@ def load_data(
origin_folder = (
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
)
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
path,
origin=origin_folder + "imdb.npz",
file_hash=( # noqa: E501
"69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
),
cache_dir=cache_dir,
)
with np.load(path, allow_pickle=True) as f:
x_train, labels_train = f["x_train"], f["y_train"]
Expand Down
Loading
Loading