From 697b0ad9c776ab0c885fb8f9935a9537a567784b Mon Sep 17 00:00:00 2001 From: Morgan Laco Date: Wed, 31 Jul 2019 21:35:44 -0400 Subject: [PATCH 1/2] Accept path to model or a TF IOHandler --- src/index.ts | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/index.ts b/src/index.ts index 801e41da..fc54891b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -35,25 +35,36 @@ export async function load(base = BASE_PATH, options = { size: IMAGE_SIZE }) { return nsfwnet } +interface IOHandler { + load: () => any +} + export class NSFWJS { public endpoints: string[] private options: nsfwjsOptions - private path: string + private pathOrIOHandler: string private model: tf.LayersModel private intermediateModels: { [layerName: string]: tf.LayersModel } = {} private normalizationOffset: tf.Scalar - constructor(base: string, options: nsfwjsOptions) { + constructor( + modelPathBaseOrIOHandler: string | IOHandler, + options: nsfwjsOptions + ) { this.options = options - this.path = `${base}model.json` this.normalizationOffset = tf.scalar(255) + + if (typeof modelPathBaseOrIOHandler === 'string') { + this.pathOrIOHandler = `${modelPathBaseOrIOHandler}model.json` + } else { + } } async load() { // this is a Layers Model - this.model = await tf.loadLayersModel(this.path) + this.model = await tf.loadLayersModel(this.pathOrIOHandler) this.endpoints = this.model.layers.map(l => l.name) const { size } = this.options From 6841bb7d3c61991069d6d3689e8c62a0eddb2ffd Mon Sep 17 00:00:00 2001 From: Morgan Laco Date: Thu, 1 Aug 2019 10:26:45 -0400 Subject: [PATCH 2/2] Set pathOrIOHandler to IOHandler --- src/index.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/index.ts b/src/index.ts index fc54891b..51c8e8e2 100644 --- a/src/index.ts +++ b/src/index.ts @@ -43,7 +43,7 @@ export class NSFWJS { public endpoints: string[] private options: nsfwjsOptions - private pathOrIOHandler: string + private pathOrIOHandler: string | IOHandler private model: tf.LayersModel private intermediateModels: { [layerName: string]: tf.LayersModel } = {} @@ -59,6 +59,7 @@ export class NSFWJS { if (typeof modelPathBaseOrIOHandler === 'string') { this.pathOrIOHandler = `${modelPathBaseOrIOHandler}model.json` } else { + this.pathOrIOHandler = modelPathBaseOrIOHandler } }