Skip to content

Commit

Permalink
Merge pull request #7 from DarthAffe/ParameterValidation
Browse files Browse the repository at this point in the history
Added Parameter validation
  • Loading branch information
DarthAffe authored May 6, 2024
2 parents de73276 + 4afbcdb commit dc58d4d
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 1 deletion.
70 changes: 70 additions & 0 deletions StableDiffusion.NET/Extensions/ParameterExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#pragma warning disable CA2208

using System;

namespace StableDiffusion.NET;

public static class ParameterExtension
{
public static void Validate(this StableDiffusionParameter parameter)
{
ArgumentNullException.ThrowIfNull(parameter, nameof(parameter));
ArgumentNullException.ThrowIfNull(parameter.ControlNet, nameof(StableDiffusionParameter.ControlNet));
ArgumentNullException.ThrowIfNull(parameter.PhotoMaker, nameof(StableDiffusionParameter.PhotoMaker));
ArgumentNullException.ThrowIfNull(parameter.NegativePrompt, nameof(StableDiffusionParameter.NegativePrompt));

ArgumentOutOfRangeException.ThrowIfNegativeOrZero(parameter.Width, nameof(StableDiffusionParameter.Width));
ArgumentOutOfRangeException.ThrowIfNegativeOrZero(parameter.Height, nameof(StableDiffusionParameter.Height));
ArgumentOutOfRangeException.ThrowIfNegativeOrZero(parameter.SampleSteps, nameof(StableDiffusionParameter.SampleSteps));

ArgumentOutOfRangeException.ThrowIfNegative(parameter.CfgScale, nameof(StableDiffusionParameter.CfgScale));
ArgumentOutOfRangeException.ThrowIfNegative(parameter.Strength, nameof(StableDiffusionParameter.Strength));

if (!Enum.IsDefined(parameter.SampleMethod)) throw new ArgumentOutOfRangeException(nameof(StableDiffusionParameter.SampleMethod));

parameter.ControlNet.Validate();
parameter.PhotoMaker.Validate();
}

public static void Validate(this StableDiffusionControlNetParameter parameter)
{
ArgumentNullException.ThrowIfNull(parameter, nameof(StableDiffusionParameter.ControlNet));

ArgumentOutOfRangeException.ThrowIfNegative(parameter.Strength, nameof(StableDiffusionControlNetParameter.Strength));
ArgumentOutOfRangeException.ThrowIfNegative(parameter.CannyHighThreshold, nameof(StableDiffusionControlNetParameter.CannyHighThreshold));
ArgumentOutOfRangeException.ThrowIfNegative(parameter.CannyLowThreshold, nameof(StableDiffusionControlNetParameter.CannyLowThreshold));
ArgumentOutOfRangeException.ThrowIfNegative(parameter.CannyWeak, nameof(StableDiffusionControlNetParameter.CannyWeak));
ArgumentOutOfRangeException.ThrowIfNegative(parameter.CannyStrong, nameof(StableDiffusionControlNetParameter.CannyStrong));
}

public static void Validate(this PhotoMakerParameter parameter)
{
ArgumentNullException.ThrowIfNull(parameter, nameof(StableDiffusionParameter.PhotoMaker));
ArgumentNullException.ThrowIfNull(parameter.InputIdImageDirectory, nameof(PhotoMakerParameter.InputIdImageDirectory));

ArgumentOutOfRangeException.ThrowIfNegative(parameter.StyleRatio, nameof(PhotoMakerParameter.StyleRatio));
}

public static void Validate(this ModelParameter parameter)
{
ArgumentNullException.ThrowIfNull(parameter, nameof(parameter));
ArgumentNullException.ThrowIfNull(parameter.TaesdPath, nameof(ModelParameter.TaesdPath));
ArgumentNullException.ThrowIfNull(parameter.LoraModelDir, nameof(ModelParameter.LoraModelDir));
ArgumentNullException.ThrowIfNull(parameter.VaePath, nameof(ModelParameter.VaePath));
ArgumentNullException.ThrowIfNull(parameter.ControlNetPath, nameof(ModelParameter.ControlNetPath));
ArgumentNullException.ThrowIfNull(parameter.EmbeddingsDirectory, nameof(ModelParameter.EmbeddingsDirectory));
ArgumentNullException.ThrowIfNull(parameter.StackedIdEmbeddingsDirectory, nameof(ModelParameter.StackedIdEmbeddingsDirectory));

if (!Enum.IsDefined(parameter.RngType)) throw new ArgumentOutOfRangeException(nameof(ModelParameter.RngType));
if (!Enum.IsDefined(parameter.Quantization)) throw new ArgumentOutOfRangeException(nameof(ModelParameter.Quantization));
if (!Enum.IsDefined(parameter.Schedule)) throw new ArgumentOutOfRangeException(nameof(ModelParameter.Schedule));
}

public static void Validate(this UpscalerModelParameter parameter)
{
ArgumentNullException.ThrowIfNull(parameter, nameof(parameter));
ArgumentNullException.ThrowIfNull(parameter.ESRGANPath, nameof(UpscalerModelParameter.ESRGANPath));

if (!Enum.IsDefined(parameter.Quantization)) throw new ArgumentOutOfRangeException(nameof(ModelParameter.Quantization));
}
}
23 changes: 22 additions & 1 deletion StableDiffusion.NET/StableDiffusionModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ static StableDiffusionModel()

public StableDiffusionModel(string modelPath, ModelParameter parameter, UpscalerModelParameter? upscalerParameter = null)
{
ArgumentException.ThrowIfNullOrWhiteSpace(modelPath, nameof(modelPath));

parameter.Validate();
upscalerParameter?.Validate();

this._modelPath = modelPath;
this._parameter = parameter;
this._upscalerParameter = upscalerParameter;
Expand Down Expand Up @@ -88,6 +93,9 @@ private void Initialize()
public StableDiffusionImage TextToImage(string prompt, StableDiffusionParameter parameter)
{
ObjectDisposedException.ThrowIf(_disposed, this);
ArgumentNullException.ThrowIfNull(prompt);

parameter.Validate();

Native.sd_image_t* result;
if (parameter.ControlNet.IsEnabled)
Expand Down Expand Up @@ -186,6 +194,9 @@ public StableDiffusionImage TextToImage(string prompt, StableDiffusionParameter
public StableDiffusionImage ImageToImage(string prompt, in ReadOnlySpan<byte> image, StableDiffusionParameter parameter)
{
ObjectDisposedException.ThrowIf(_disposed, this);
ArgumentNullException.ThrowIfNull(prompt);

parameter.Validate();

fixed (byte* imagePtr = image)
{
Expand All @@ -207,6 +218,9 @@ public StableDiffusionImage ImageToImage(string prompt, StableDiffusionImage ima
private StableDiffusionImage ImageToImage(string prompt, Native.sd_image_t image, StableDiffusionParameter parameter)
{
ObjectDisposedException.ThrowIf(_disposed, this);
ArgumentNullException.ThrowIfNull(prompt);

parameter.Validate();

Native.sd_image_t* result;
if (parameter.ControlNet.IsEnabled)
Expand Down Expand Up @@ -361,7 +375,14 @@ public void Dispose()
}

public static void Convert(string modelPath, string vaePath, Quantization quantization, string outputPath)
=> Native.convert(modelPath, vaePath, outputPath, quantization);
{
ArgumentException.ThrowIfNullOrWhiteSpace(nameof(modelPath));
ArgumentException.ThrowIfNullOrWhiteSpace(nameof(outputPath));
ArgumentNullException.ThrowIfNull(vaePath);
if (!Enum.IsDefined(quantization)) throw new ArgumentOutOfRangeException(nameof(quantization));

Native.convert(modelPath, vaePath, outputPath, quantization);
}

public static string GetSystemInfo()
{
Expand Down

0 comments on commit dc58d4d

Please sign in to comment.