Skip to content

Commit

Permalink
Allow passing a custom cache_dir to tf.keras.datasets.load_data. This…
Browse files Browse the repository at this point in the history
… is helpful when the default location `~/.keras` in the home directory has limited disk space.

PiperOrigin-RevId: 713015638
  • Loading branch information
sampathweb authored and tensorflower-gardener committed Jan 7, 2025
1 parent 6f44991 commit 8fdeeea
Show file tree
Hide file tree
Showing 21 changed files with 85 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.boston_housing"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], "
argspec: "args=[\'path\', \'test_split\', \'seed\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar10"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar100"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'label_mode\'], varargs=None, keywords=None, defaults=[\'fine\'], "
argspec: "args=[\'label_mode\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'fine\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.fashion_mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'mnist.npz\'], "
argspec: "args=[\'path\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'mnist.npz\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.boston_housing"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], "
argspec: "args=[\'path\', \'test_split\', \'seed\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar10"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.cifar100"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'label_mode\'], varargs=None, keywords=None, defaults=[\'fine\'], "
argspec: "args=[\'label_mode\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'fine\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.fashion_mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'cache_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.mnist"
tf_module {
member_method {
name: "load_data"
argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'mnist.npz\'], "
argspec: "args=[\'path\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'mnist.npz\', \'None\'], "
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ tf_module {
}
member_method {
name: "load_data"
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\', \'cache_dir\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\', \'None\'], "
}
}
19 changes: 14 additions & 5 deletions tf_keras/datasets/boston_housing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""Boston housing price regression dataset."""

import os

import numpy as np

from tf_keras.utils.data_utils import get_file
Expand All @@ -23,7 +25,9 @@


@keras_export("keras.datasets.boston_housing.load_data")
def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
def load_data(
path="boston_housing.npz", test_split=0.2, seed=113, cache_dir=None
):
"""Loads the Boston Housing dataset.
This is a dataset taken from the StatLib library which is maintained at
Expand All @@ -43,11 +47,12 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
[StatLib website](http://lib.stat.cmu.edu/datasets/boston).
Args:
path: path where to cache the dataset locally
(relative to `~/.keras/datasets`).
path: path where to cache the dataset locally (relative to the
`datasets` folder of `cache_dir`, by default `~/.keras/datasets`).
test_split: fraction of the data to reserve as test set.
seed: Random seed for shuffling the data
before computing the test split.
seed: Random seed for shuffling the data before computing the test split.
cache_dir: base directory for the local cache; the dataset is stored in
its `datasets` subdirectory. When None, defaults to `~/.keras`.
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand All @@ -64,12 +69,16 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
origin_folder = (
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
)
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
path,
origin=origin_folder + "boston_housing.npz",
file_hash=( # noqa: E501
"f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
),
cache_dir=cache_dir,
)
with np.load(path, allow_pickle=True) as f:
x = f["x"]
Expand Down
10 changes: 9 additions & 1 deletion tf_keras/datasets/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@keras_export("keras.datasets.cifar10.load_data")
def load_data():
def load_data(cache_dir=None):
"""Loads the CIFAR10 dataset.
This is a dataset of 50,000 32x32 color training images and 10,000 test
Expand All @@ -49,6 +49,10 @@ def load_data():
| 8 | ship |
| 9 | truck |
Args:
cache_dir: base directory for the local cache; the dataset is stored in
its `datasets` subdirectory. When None, defaults to `~/.keras`.
Returns:
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand Down Expand Up @@ -78,13 +82,17 @@ def load_data():
"""
dirname = "cifar-10-batches-py"
origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
dirname,
origin=origin,
untar=True,
file_hash=( # noqa: E501
"6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
),
cache_dir=cache_dir,
)

num_train_samples = 50000
Expand Down
8 changes: 7 additions & 1 deletion tf_keras/datasets/cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


@keras_export("keras.datasets.cifar100.load_data")
def load_data(label_mode="fine"):
def load_data(label_mode="fine", cache_dir=None):
"""Loads the CIFAR100 dataset.
This is a dataset of 50,000 32x32 color training images and
Expand All @@ -39,6 +39,8 @@ def load_data(label_mode="fine"):
label_mode: one of "fine", "coarse". If it is "fine" the category labels
are the fine-grained labels, if it is "coarse" the output labels are the
coarse-grained superclasses.
cache_dir: base directory for the local cache; the dataset is stored in
its `datasets` subdirectory. When None, defaults to `~/.keras`.
Returns:
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand Down Expand Up @@ -75,13 +77,17 @@ def load_data(label_mode="fine"):

dirname = "cifar-100-python"
origin = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
dirname,
origin=origin,
untar=True,
file_hash=( # noqa: E501
"85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7"
),
cache_dir=cache_dir,
)

fpath = os.path.join(path, "train")
Expand Down
20 changes: 16 additions & 4 deletions tf_keras/datasets/fashion_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


@keras_export("keras.datasets.fashion_mnist.load_data")
def load_data():
def load_data(cache_dir=None):
"""Loads the Fashion-MNIST dataset.
This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,
Expand All @@ -48,6 +48,10 @@ def load_data():
| 8 | Bag |
| 9 | Ankle boot |
Args:
cache_dir: base directory for the local cache; the files are stored in
its `datasets/fashion-mnist` subdirectory. When None, defaults to
`~/.keras`.
Returns:
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand Down Expand Up @@ -77,7 +81,6 @@ def load_data():
The copyright for Fashion-MNIST is held by Zalando SE.
Fashion-MNIST is licensed under the [MIT license](
https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).
"""
dirname = os.path.join("datasets", "fashion-mnist")
base = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
Expand All @@ -87,10 +90,19 @@ def load_data():
"t10k-labels-idx1-ubyte.gz",
"t10k-images-idx3-ubyte.gz",
]

if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
paths = []
for fname in files:
paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))
paths.append(
get_file(
fname,
origin=base + fname,
cache_dir=cache_dir,
cache_subdir=dirname,
)
)

with gzip.open(paths[0], "rb") as lbpath:
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
Expand Down
8 changes: 8 additions & 0 deletions tf_keras/datasets/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""IMDB sentiment classification dataset."""

import json
import os

import numpy as np

Expand All @@ -36,6 +37,7 @@ def load_data(
start_char=1,
oov_char=2,
index_from=3,
cache_dir=None,
**kwargs,
):
"""Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).
Expand Down Expand Up @@ -73,6 +75,8 @@ def load_data(
Words that were cut out because of the `num_words` or
`skip_top` limits will be replaced with this character.
index_from: int. Index actual words with this index and higher.
cache_dir: base directory for the local cache; the dataset is stored in
its `datasets` subdirectory. When None, defaults to `~/.keras`.
**kwargs: Used for backwards compatibility.
Returns:
Expand Down Expand Up @@ -108,12 +112,16 @@ def load_data(
origin_folder = (
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
)
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
path,
origin=origin_folder + "imdb.npz",
file_hash=( # noqa: E501
"69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
),
cache_dir=cache_dir,
)
with np.load(path, allow_pickle=True) as f:
x_train, labels_train = f["x_train"], f["y_train"]
Expand Down
12 changes: 9 additions & 3 deletions tf_keras/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""MNIST handwritten digits dataset."""
import os

import numpy as np

Expand All @@ -23,7 +24,7 @@


@keras_export("keras.datasets.mnist.load_data")
def load_data(path="mnist.npz"):
def load_data(path="mnist.npz", cache_dir=None):
"""Loads the MNIST dataset.
This is a dataset of 60,000 28x28 grayscale images of the 10 digits,
Expand All @@ -32,8 +33,9 @@ def load_data(path="mnist.npz"):
[MNIST homepage](http://yann.lecun.com/exdb/mnist/).
Args:
path: path where to cache the dataset locally
(relative to `~/.keras/datasets`).
path: path where to cache the dataset locally (relative to the
`datasets` folder of `cache_dir`).
cache_dir: base directory for the local cache; the dataset is stored in
its `datasets` subdirectory. When None, defaults to `~/.keras`.
Returns:
Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.
Expand Down Expand Up @@ -72,12 +74,16 @@ def load_data(path="mnist.npz"):
origin_folder = (
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/"
)
if cache_dir:
cache_dir = os.path.expanduser(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
path = get_file(
path,
origin=origin_folder + "mnist.npz",
file_hash=( # noqa: E501
"731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
),
cache_dir=cache_dir,
)
with np.load(path, allow_pickle=True) as f:
x_train, y_train = f["x_train"], f["y_train"]
Expand Down
Loading

0 comments on commit 8fdeeea

Please sign in to comment.