Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for ExtractImagePatches #2188

Merged
merged 3 commits into from
Dec 24, 2024

Conversation

nanoskript
Copy link
Contributor

Closes: #436

This rewrite is based on this comment: #436 (comment) with changes to make it more general and translatable into tf2onnx.

Equivalent TensorFlow function and automated test script (expand)
import tensorflow as tf
import numpy as np
from hypothesis import given, strategies as st, settings, assume


def our_extract_image_patches(sizes, strides, rates, padding):
    # TensorFlow's constraints.
    assert sizes[0] == 1 and sizes[3] == 1
    assert strides[0] == 1 and strides[3] == 1
    assert rates[0] == 1 and rates[3] == 1
    assert padding in ["SAME", "VALID"]

    # Extract size.
    [_, size_rows, size_cols, _] = sizes

    @tf.function
    def function(tensor):
        # Input shape of [N, H, W, C].
        tensor_shape = tensor.shape

        # Transpose and reshape to [N * C, H, W, 1].
        tensor = tf.transpose(tensor, perm=[0, 3, 1, 2])
        tensor = tf.reshape(tensor, [
            tensor_shape[0] * tensor_shape[3],
            tensor_shape[1],
            tensor_shape[2],
            1,
        ])

        # Convolve with identity kernel into [N * C, ?H, ?W, K].
        k = size_rows * size_cols
        kernel = tf.reshape(tf.eye(k), [size_rows, size_cols, 1, k])
        convolution = tf.nn.conv2d(tensor, kernel, strides=strides, padding=padding, dilations=rates)

        # Reshape into [N, C, ?H, ?W, K].
        reshaped = tf.reshape(convolution, [
            tensor_shape[0],
            tensor_shape[3],
            convolution.shape[1],
            convolution.shape[2],
            k,
        ])

        # Transpose and reshape into [N, ?H, ?W, C * K].
        patches = tf.transpose(reshaped, perm=[0, 2, 3, 4, 1])
        return tf.reshape(patches, [
            tensor_shape[0],
            convolution.shape[1],
            convolution.shape[2],
            tensor_shape[3] * k,
        ])

    return function


def tf_extract_image_patches(sizes, strides, rates, padding):
    @tf.function
    def function(tensor):
        return tf.image.extract_patches(
            tensor,
            sizes=sizes,
            strides=strides,
            rates=rates,
            padding=padding,
        )

    return function


@settings(max_examples=5000)
@given(
    st.lists(st.integers(min_value=1, max_value=20), min_size=4, max_size=4),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.integers(min_value=1, max_value=20),
    st.sampled_from(["VALID", "SAME"]),
)
def test_equal(shape, size_rows, size_cols, stride_rows, stride_cols, dil_rows, dil_cols, padding):
    sizes = [1, size_rows, size_cols, 1]
    strides = [1, stride_rows, stride_cols, 1]
    rates = [1, dil_rows, dil_cols, 1]

    try:
        tensor = tf.cast(tf.reshape(tf.range(np.prod(shape)), shape), dtype=tf.float32)
        tfs = tf_extract_image_patches(sizes, strides, rates, padding)(tensor)

        if 0 in tfs.shape:
            # We cannot handle operations that produce empty outputs.
            assume(False)
    except ValueError:
        # Ignore input if TensorFlow would fail.
        assume(False)
        return

    ours = our_extract_image_patches(sizes, strides, rates, padding)(tensor)
    assert tf.reduce_all(tf.math.equal(tfs, ours)).numpy()

Output from pytest convolve.py --hypothesis-show-statistics (no failures):

convolve.py::test_equal:

  - during generate phase (70.74 seconds):
    - Typical runtimes: ~ 1-14 ms, of which < 1ms in data generation
    - 5000 passing examples, 0 failing examples, 3913 invalid examples

  - Stopped because settings.max_examples=5000

@nanoskript nanoskript force-pushed the add-extract-image-patches branch 2 times, most recently from b4e1d24 to 2da669d Compare June 16, 2023 12:58
@nanoskript nanoskript marked this pull request as ready for review June 16, 2023 13:02
@nanoskript nanoskript force-pushed the add-extract-image-patches branch from 2da669d to 3aa08e8 Compare July 7, 2023 04:40
@fatcat-z
Copy link
Collaborator

fatcat-z commented Jul 30, 2023

Thanks you for putting that solution into this PR, and it looks great!

Rewriter is designed to rewrite the ONNX graph after we transform each tf op into the corresponding onnx op. Each rewriter will search the ONNX graph following a given pattern. Once the pattern is matched, those involved onnx ops will be replaced with some other ops for an optimization in further inference.

In this case, ExtractImagePatches is just a tf op which is not supported by tf2onnx yet. So, your implementations should be put into nn.py file instead of adding a rewriter. Please add it into nn.py, just like adding a new tf op support.

Please feel free to refer to this comment for more details.

@nanoskript
Copy link
Contributor Author

Hi @fatcat-z,

Rewriter is designed to rewrite the ONNX graph after we transform each tf op into the corresponding onnx op.

I'm not entirely sure if this is true. From my understanding, the rewriters are ran before each operation is converted into an ONNX operation:

run_rewriters(g, rewriters, continue_on_error)

where line 622 performs the conversion (?). There do appear to be late rewriters that run after the mapping occurs, but in general, it seems like the rewriting and optimization steps are separate.

I chose to implement this as a rewrite in order to avoid duplicating the construction of the Conv2D node but if you would still prefer for this to be implemented in nn.py, please let me know.

@fatcat-z
Copy link
Collaborator

fatcat-z commented Aug 2, 2023

Hi @fatcat-z,

Rewriter is designed to rewrite the ONNX graph after we transform each tf op into the corresponding onnx op.

I'm not entirely sure if this is true. From my understanding, the rewriters are ran before each operation is converted into an ONNX operation:

run_rewriters(g, rewriters, continue_on_error)

where line 622 performs the conversion (?). There do appear to be late rewriters that run after the mapping occurs, but in general, it seems like the rewriting and optimization steps are separate.
I chose to implement this as a rewrite in order to avoid duplicating the construction of the Conv2D node but if you would still prefer for this to be implemented in nn.py, please let me know.

No, graphs_from_tf() function will transfer the tf graph to onnx graph meaning each tf op has been converted to onnx op, if possible. Afterwards, process_parsed_graph() will be called to finish those rewriters and optimizations.

Yes, please implement this as an op in nn.py instead of creating a new rewriter. Thanks.

@ruihu102
Copy link

Is this new operator going to be merged into main?

@nanoskript
Copy link
Contributor Author

Is this new operator going to be merged into main?

I don't think this operator can be considered to be new but I'm aiming to get the requested changes done sometime within the week.

@nanoskript nanoskript force-pushed the add-extract-image-patches branch 2 times, most recently from 7c7cc34 to 16643fd Compare July 9, 2024 06:11
@nanoskript
Copy link
Contributor Author

Sorry for the delay! I've implemented this operation inside of nn.py instead of as a rewriter. Let me know if anything else needs to be changed!

@nanoskript nanoskript force-pushed the add-extract-image-patches branch from 16643fd to 36f203b Compare December 15, 2024 04:34
@nanoskript
Copy link
Contributor Author

@fatcat-z Hi! Apologies for the ping! Do you think this PR could be merged into main at some time?

rates = node.get_attr_value("rates")
padding = node.get_attr_str("padding")

# Our constraints.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please provide more details about this constraint so people know how to improve in future?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've expanded this comment and given an example of a call that succeeds in Tensorflow but fails for this particular implementation.

Copy link
Collaborator

@fatcat-z fatcat-z left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@fatcat-z fatcat-z merged commit 6298b26 into onnx:main Dec 24, 2024
42 checks passed
@nanoskript nanoskript deleted the add-extract-image-patches branch December 24, 2024 13:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ValueError: tensorflow op ExtractImagePatches is not supported
3 participants