diff --git a/aiotieba/database.py b/aiotieba/database.py index 44be34c0..a748df98 100644 --- a/aiotieba/database.py +++ b/aiotieba/database.py @@ -558,7 +558,7 @@ async def _create_table_imghash(self) -> None: async with conn.cursor() as cursor: await cursor.execute( f"CREATE TABLE IF NOT EXISTS `imghash_{self.fname}` \ - (`img_hash` CHAR(16) PRIMARY KEY, `raw_hash` CHAR(40) UNIQUE NOT NULL, `permission` TINYINT NOT NULL DEFAULT 0, `note` VARCHAR(64) NOT NULL DEFAULT '', `record_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, \ + (`img_hash` CHAR(16) PRIMARY KEY, `img_hash_uint64` BIGINT UNSIGNED UNIQUE NOT NULL, `raw_hash` CHAR(40) UNIQUE NOT NULL, `permission` TINYINT NOT NULL DEFAULT 0, `note` VARCHAR(64) NOT NULL DEFAULT '', `record_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, \ INDEX `permission`(permission), INDEX `record_time`(record_time))" ) @@ -580,8 +580,8 @@ async def add_imghash(self, img_hash: str, raw_hash: str, /, permission: int = 0 async with self._pool.acquire() as conn: async with conn.cursor() as cursor: await cursor.execute( - f"REPLACE INTO `imghash_{self.fname}` VALUES (%s,%s,%s,%s,DEFAULT)", - (img_hash, raw_hash, permission, note), + f"REPLACE INTO `imghash_{self.fname}` VALUES (%s,CONV(%s,16,10),%s,%s,%s,DEFAULT)", + (img_hash, img_hash, raw_hash, permission, note), ) except aiomysql.Error as err: LOG.warning(f"{err}. forum={self.fname} img_hash={img_hash}") @@ -612,12 +612,13 @@ async def del_imghash(self, img_hash: str) -> bool: LOG.info(f"Succeeded. forum={self.fname} img_hash={img_hash}") return True - async def get_imghash(self, img_hash: str) -> int: + async def get_imghash(self, img_hash: str, hamming_distance: int=0) -> int: """ 获取表imghash_{fname}中img_hash的封锁级别 Args: img_hash (str): 图像的phash + hamming_distance: 最大海明距离 默认为0(图像phash完全一致) Returns: int: 封锁级别 @@ -626,9 +627,14 @@ async def get_imghash(self, img_hash: str) -> int: try: async with self._pool.acquire() as conn: async with conn.cursor() as cursor: - await cursor.execute( - f"SELECT `permission` FROM `imghash_{self.fname}` WHERE `img_hash`=%s", (img_hash,) - ) + if hamming_distance > 0: + await cursor.execute( + f"SELECT `permission`, BIT_COUNT(`img_hash_uint64` ^ CONV(%s,16,10)) AS hd FROM `imghash_{self.fname}` HAVING hd <= %s ORDER BY hd ASC", (img_hash, hamming_distance) + ) + else: + await cursor.execute( + f"SELECT `permission` FROM `imghash_{self.fname}` WHERE `img_hash`=%s", (img_hash,) + ) except aiomysql.Error as err: LOG.warning(f"{err}. forum={self.fname} img_hash={img_hash}") return False diff --git a/aiotieba/reviewer.py b/aiotieba/reviewer.py index 81984a60..6a839358 100644 --- a/aiotieba/reviewer.py +++ b/aiotieba/reviewer.py @@ -419,19 +419,20 @@ def compute_imghash(self, image: "np.ndarray") -> str: return img_hash - async def get_imghash(self, image: "np.ndarray") -> int: + async def get_imghash(self, image: "np.ndarray", hamming_distance: int=0) -> int: """ 获取图像的封锁级别 Args: image (np.ndarray): 图像 + hamming_distance: 最大海明距离 默认为0(图像phash完全一致) Returns: int: 封锁级别 """ if img_hash := self.compute_imghash(image): - return await self.db.get_imghash(img_hash) + return await self.db.get_imghash(img_hash, hamming_distance) return 0 async def get_imghash_full(self, image: "np.ndarray") -> Tuple[int, str]: