Add README for unet layers. (#63)
* Add README for unet layers.

* update
yichunk authored Jun 20, 2024
1 parent 0a6c9b3 commit 9b06abe
Showing 2 changed files with 24 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ai_edge_torch/generative/layers/README.md
@@ -13,6 +13,8 @@ These two files provide the following common Python helper functions:
And also the following `nn.Module` classes:
* `TransformerBlock`
* `CausalSelfAttention`
* `SelfAttention`
* `CrossAttention`

## Builder class for common layers
`builder.py` provides the following helper functions:
22 changes: 22 additions & 0 deletions ai_edge_torch/generative/layers/unet/README.md
@@ -0,0 +1,22 @@
# ODML UNet Layers
Common PyTorch building blocks to re-author UNet based models.

## Blocks 2D layers
`blocks_2d.py` provides common building blocks used in AutoEncoder and general UNet-like models. Each block has a corresponding config class in `model_config.py`, and each block is initialized from its config. `blocks_2d.py` provides the following blocks:
* `ResidualBlock2D`: a basic residual layer containing two convolution layers, with an optional time embedding layer.
* `AttentionBlock2D`: self-attention layer for 2D tensors.
* `CrossAttentionBlock2D`: cross-attention layer for 2D tensors, between a latent tensor and a context tensor.
* `FeedForwardBlock2D`: basic feed-forward layer used in the 2D transformer block.
* `TransformerBlock2D`: building block for text-to-image diffusion models, containing `AttentionBlock2D`, `CrossAttentionBlock2D` and `FeedForwardBlock2D`.
* `DownEncoderBlock2D`: encoder block used in AutoEncoder and UNet, with an optional downsampling layer.
* `UpDecoderBlock2D`: decoder block used in AutoEncoder and UNet, with an optional upsampling layer.
* `SkipUpDecoderBlock2D`: decoder block used in UNet, with skip connections from the encoder.
* `MidBlock2D`: middle block used in AutoEncoder and UNet.
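
The config-driven construction described above can be illustrated with a minimal, framework-free sketch. It is not the library's code: `ToyResidualBlockConfig`, `ToyResidualBlock`, and every field name are hypothetical, and the block only records the names of the sublayers it would build rather than creating real PyTorch modules. The shortcut projection is a common design choice assumed here, not something this README states.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyResidualBlockConfig:
    """Hypothetical config; field names are illustrative, not the library's."""

    in_channels: int
    out_channels: int
    # None disables the optional time embedding layer.
    time_embedding_channels: Optional[int] = None


class ToyResidualBlock:
    """Stand-in for a config-initialized block.

    It records the sublayers it would build instead of creating real modules.
    """

    def __init__(self, config: ToyResidualBlockConfig):
        self.config = config
        cin, cout = config.in_channels, config.out_channels
        # Two convolution layers, as in the README's description of a residual block.
        self.sublayers: List[str] = [f"conv1({cin}->{cout})", f"conv2({cout}->{cout})"]
        if config.time_embedding_channels is not None:
            # Project the time embedding to the block's channel count between the convs.
            tch = config.time_embedding_channels
            self.sublayers.insert(1, f"time_proj({tch}->{cout})")
        if cin != cout:
            # Assumed: a projection on the shortcut when channel counts differ.
            self.sublayers.append(f"shortcut({cin}->{cout})")


plain = ToyResidualBlock(ToyResidualBlockConfig(in_channels=64, out_channels=64))
timed = ToyResidualBlock(ToyResidualBlockConfig(64, 128, time_embedding_channels=320))
```

Keeping all options in a config object means a block's structure can be changed (for example, enabling the time embedding) without changing the block's constructor signature.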

## Builder class for common layers
`builder.py` provides the following helper functions:
* `build_upsampling`
* `build_downsampling`
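
This README does not describe the signatures or sampling modes of these helpers, so the sketch below does not reproduce them. It only illustrates, in plain Python on nested lists, what upsampling and downsampling do to the spatial size of a feature map. The function names `upsample_nearest_2x` and `downsample_avg_2x` are hypothetical, and nearest-neighbor repetition and 2x2 averaging are assumed as example modes.

```python
from typing import List

Grid = List[List[float]]


def upsample_nearest_2x(x: Grid) -> Grid:
    """Double height and width by repeating each value in a 2x2 patch."""
    out: Grid = []
    for row in x:
        wide = [v for v in row for _ in range(2)]  # repeat each value horizontally
        out.append(wide)
        out.append(list(wide))  # repeat the widened row vertically
    return out


def downsample_avg_2x(x: Grid) -> Grid:
    """Halve height and width by averaging each non-overlapping 2x2 patch."""
    h, w = len(x), len(x[0])
    if h % 2 or w % 2:
        raise ValueError("height and width must be even")
    return [
        [
            (x[i][j] + x[i][j + 1] + x[i + 1][j] + x[i + 1][j + 1]) / 4.0
            for j in range(0, w, 2)
        ]
        for i in range(0, h, 2)
    ]


feature_map = [[1.0, 2.0], [3.0, 4.0]]
up = upsample_nearest_2x(feature_map)  # 2x2 -> 4x4
down = downsample_avg_2x(up)  # 4x4 -> 2x2
```

In an encoder/decoder pair, the downsampling steps shrink the spatial size and the upsampling steps restore it, which is why the two helpers are listed together.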

## Model config class
`model_config.py` provides the config classes used by the 2D blocks, the utility layers, and the full AutoEncoder and UNet models.
