Skip to content

Commit

Permalink
Allow passing a custom cache_dir to tf.keras.datasets.load_data. This…
Browse files Browse the repository at this point in the history
… is helpful when the default location `~/.keras` in home directory has limited disk space.

PiperOrigin-RevId: 713015638
  • Loading branch information
sampathweb authored and tensorflower-gardener committed Jan 8, 2025
1 parent 6f44991 commit 39b77af
Show file tree
Hide file tree
Showing 22 changed files with 86 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
${{ runner.os }}-pip-
- name: Install dependencies
run: |
pip install black==22.3.0 isort==5.10.1 flake8==4.0.1
pip install black==22.3.0 isort==5.10.1 flake8==4.0.1 'importlib_metadata<5'
- name: Format the code
run: sh shell/format.sh

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.boston_housing"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], "
argspec: "args=[\'path\', \'test_split\', \'seed\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar10"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar100"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'label_mode\'], varargs=None, keywords=None, defaults=[\'fine\'], "
argspec: "args=[\'label_mode\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'fine\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.fashion_mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'mnist.npz\'], "
argspec: "args=[\'path\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'mnist.npz\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.boston_housing"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], "
argspec: "args=[\'path\', \'test_split\', \'seed\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar10"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar100"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'label_mode\'], varargs=None, keywords=None, defaults=[\'fine\'], "
argspec: "args=[\'label_mode\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'fine\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.fashion_mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'mnist.npz\'], "
argspec: "args=[\'path\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'mnist.npz\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
19 changes: 14 additions & 5 deletions tf_keras/datasets/boston_housing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""Boston housing price regression dataset."""

import os

import numpy as np

from tf_keras.utils.data_utils import get_file
Expand All @@ -23,7 +25,9 @@


@keras_export("keras.datasets.boston_housing.load_data")
def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
def load_data(
path="boston_housing.npz", test_split=0.2, seed=113, cache_dir=None
):
"""Loads the Boston Housing dataset.
This is a dataset taken from the StatLib library which is maintained at
Expand All @@ -43,11 +47,12 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
[StatLib website](http://lib.stat.cmu.edu/datasets/boston).
Args:
path: path where to cache the dataset locally
(relative to `~/.keras/datasets`).
path: path where to cache the dataset locally (relative to
`~/.keras/datasets`).
test_split: fraction of the data to reserve as test set.
seed: Random seed for shuffling the data
before computing the test split.
seed: Random seed for shuffling the data before computing the test split.
cache_dir: directory where to cache the dataset locally. When None,
defaults to `~/.keras/datasets`.
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand All @@ -64,12 +69,16 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
origin_folder = (
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
)
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
path,
origin=origin_folder + "boston_housing.npz",
file_hash=( # noqa: E501
"f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
),
cache_dir=cache_dir,
)
with np.load(path, allow_pickle=True) as f:
x = f["x"]
Expand Down
10 changes: 9 additions & 1 deletion tf_keras/datasets/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@keras_export("keras.datasets.cifar10.load_data")
def load_data():
def load_data(cache_dir=None):
"""Loads the CIFAR10 dataset.
This is a dataset of 50,000 32x32 color training images and 10,000 test
Expand All @@ -49,6 +49,10 @@ def load_data():
| 8 | ship |
| 9 | truck |
Args:
cache_dir: directory where to cache the dataset locally. When None,
defaults to `~/.keras/datasets`.
Returns:
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand Down Expand Up @@ -78,13 +82,17 @@ def load_data():
"""
dirname = "cifar-10-batches-py"
origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
dirname,
origin=origin,
untar=True,
file_hash=( # noqa: E501
"6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
),
cache_dir=cache_dir,
)

num_train_samples = 50000
Expand Down
8 changes: 7 additions & 1 deletion tf_keras/datasets/cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@keras_export("keras.datasets.cifar100.load_data")
def load_data(label_mode="fine"):
def load_data(label_mode="fine", cache_dir=None):
"""Loads the CIFAR100 dataset.
This is a dataset of 50,000 32x32 color training images and
Expand All @@ -39,6 +39,8 @@ def load_data(label_mode="fine"):
label_mode: one of "fine", "coarse". If it is "fine" the category labels
are the fine-grained labels, if it is "coarse" the output labels are the
coarse-grained superclasses.
cache_dir: directory where to cache the dataset locally. When None,
defaults to `~/.keras/datasets`.
Returns:
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand Down Expand Up @@ -75,13 +77,17 @@ def load_data(label_mode="fine"):

dirname = "cifar-100-python"
origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
dirname,
origin=origin,
untar=True,
file_hash=( # noqa: E501
"85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
),
cache_dir=cache_dir,
)

fpath = os.path.join(path, "train")
Expand Down
20 changes: 16 additions & 4 deletions tf_keras/datasets/fashion_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


@keras_export("keras.datasets.fashion_mnist.load_data")
def load_data():
def load_data(cache_dir=None):
"""Loads the Fashion-MNIST dataset.
This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,
Expand All @@ -48,6 +48,10 @@ def load_data():
| 8 | Bag |
| 9 | Ankle boot |
Args:
cache_dir: directory where to cache the dataset locally. When None,
defaults to `~/.keras/datasets`.
Returns:
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand Down Expand Up @@ -77,7 +81,6 @@ def load_data():
The copyright for Fashion-MNIST is held by Zalando SE.
Fashion-MNIST is licensed under the [MIT license](
https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).
"""
dirname = os.path.join("datasets", "fashion-mnist")
base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
Expand All @@ -87,10 +90,19 @@ def load_data():
"t10k-labels-idx1-ubyte.gz",
"t10k-images-idx3-ubyte.gz",
]

if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
paths = []
for fname in files:
paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))
paths.append(
get_file(
fname,
origin=base + fname,
cache_dir=cache_dir,
cache_subdir=dirname,
)
)

with gzip.open(paths[0], "rb") as lbpath:
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
Expand Down
8 changes: 8 additions & 0 deletions tf_keras/datasets/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""IMDB sentiment classification dataset."""

import json
import os

import numpy as np

Expand All @@ -36,6 +37,7 @@ def load_data(
start_char=1,
oov_char=2,
index_from=3,
cache_dir=None,
**kwargs,
):
"""Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
Expand Down Expand Up @@ -73,6 +75,8 @@ def load_data(
Words that were cut out because of the `num_words` or
`skip_top` limits will be replaced with this character.
index_from: int. Index actual words with this index and higher.
cache_dir: directory where to cache the dataset locally. When None,
defaults to `~/.keras/datasets`.
**kwargs: Used for backwards compatibility.
Returns:
Expand Down Expand Up @@ -108,12 +112,16 @@ def load_data(
origin_folder = (
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
)
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
path,
origin=origin_folder + "imdb.npz",
file_hash=( # noqa: E501
"69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
),
cache_dir=cache_dir,
)
with np.load(path, allow_pickle=True) as f:
x_train, labels_train = f["x_train"], f["y_train"]
Expand Down
12 changes: 9 additions & 3 deletions tf_keras/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""MNIST handwritten digits dataset."""
import os

import numpy as np

Expand All @@ -23,7 +24,7 @@


@keras_export("keras.datasets.mnist.load_data")
def load_data(path="mnist.npz"):
def load_data(path="mnist.npz", cache_dir=None):
"""Loads the MNIST dataset.
This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
Expand All @@ -32,8 +33,9 @@ def load_data(path="mnist.npz"):
[MNIST homepage](http://yann.lecun.com/exdb/mnist/).
Args:
path: path where to cache the dataset locally
(relative to `~/.keras/datasets`).
path: path where to cache the dataset locally relative to cache_dir.
cache_dir: directory where to cache the dataset locally. When None,
defaults to `~/.keras/datasets`.
Returns:
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand Down Expand Up @@ -72,12 +74,16 @@ def load_data(path="mnist.npz"):
origin_folder = (
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
)
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
path,
origin=origin_folder + "mnist.npz",
file_hash=( # noqa: E501
"731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
),
cache_dir=cache_dir,
)
with np.load(path, allow_pickle=True) as f:
x_train, y_train = f["x_train"], f["y_train"]
Expand Down
Loading

0 comments on commit 39b77af

Please sign in to comment.