Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

基于海明距离的相似图像识别 #63

Merged
merged 1 commit into from
Dec 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions aiotieba/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ async def _create_table_imghash(self) -> None:
async with conn.cursor() as cursor:
await cursor.execute(
f"CREATE TABLE IF NOT EXISTS `imghash_{self.fname}` \
(`img_hash` CHAR(16) PRIMARY KEY, `raw_hash` CHAR(40) UNIQUE NOT NULL, `permission` TINYINT NOT NULL DEFAULT 0, `note` VARCHAR(64) NOT NULL DEFAULT '', `record_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, \
(`img_hash` CHAR(16) PRIMARY KEY, `img_hash_uint64` BIGINT UNSIGNED UNIQUE NOT NULL, `raw_hash` CHAR(40) UNIQUE NOT NULL, `permission` TINYINT NOT NULL DEFAULT 0, `note` VARCHAR(64) NOT NULL DEFAULT '', `record_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, \
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个uint64确实是个好想法,比字符串索引快得多得多

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用定长blob(max_length_of_phash)或定长uint16/32/64存储bytes array不是常识?
只有前端给用户看的bytes才需要给他看bin2hex后的字符串

INDEX `permission`(permission), INDEX `record_time`(record_time))"
)

Expand All @@ -580,8 +580,8 @@ async def add_imghash(self, img_hash: str, raw_hash: str, /, permission: int = 0
async with self._pool.acquire() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
f"REPLACE INTO `imghash_{self.fname}` VALUES (%s,%s,%s,%s,DEFAULT)",
(img_hash, raw_hash, permission, note),
f"REPLACE INTO `imghash_{self.fname}` VALUES (%s,CONV(%s,16,10),%s,%s,%s,DEFAULT)",
(img_hash, img_hash, raw_hash, permission, note),
)
except aiomysql.Error as err:
LOG.warning(f"{err}. forum={self.fname} img_hash={img_hash}")
Expand Down Expand Up @@ -612,12 +612,13 @@ async def del_imghash(self, img_hash: str) -> bool:
LOG.info(f"Succeeded. forum={self.fname} img_hash={img_hash}")
return True

async def get_imghash(self, img_hash: str) -> int:
async def get_imghash(self, img_hash: str, hamming_distance: int=0) -> int:
"""
获取表imghash_{fname}中img_hash的封锁级别

Args:
img_hash (str): 图像的phash
hamming_distance: 最大海明距离 默认为0(图像phash完全一致)

Returns:
int: 封锁级别
Expand All @@ -626,9 +627,14 @@ async def get_imghash(self, img_hash: str) -> int:
try:
async with self._pool.acquire() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
f"SELECT `permission` FROM `imghash_{self.fname}` WHERE `img_hash`=%s", (img_hash,)
)
if hamming_distance > 0:
await cursor.execute(
f"SELECT `permission`, BIT_COUNT(`img_hash_uint64` ^ CONV(%s,16,10)) AS hd FROM `imghash_{self.fname}` HAVING hd <= %s ORDER BY hd ASC", (img_hash, hamming_distance)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)
else:
await cursor.execute(
f"SELECT `permission` FROM `imghash_{self.fname}` WHERE `img_hash`=%s", (img_hash,)
)
except aiomysql.Error as err:
LOG.warning(f"{err}. forum={self.fname} img_hash={img_hash}")
return False
Expand Down
5 changes: 3 additions & 2 deletions aiotieba/reviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,19 +419,20 @@ def compute_imghash(self, image: "np.ndarray") -> str:

return img_hash

async def get_imghash(self, image: "np.ndarray") -> int:
async def get_imghash(self, image: "np.ndarray", hamming_distance: int=0) -> int:
"""
获取图像的封锁级别

Args:
image (np.ndarray): 图像
hamming_distance: 最大海明距离 默认为0(图像phash完全一致)

Returns:
int: 封锁级别
"""

if img_hash := self.compute_imghash(image):
return await self.db.get_imghash(img_hash)
return await self.db.get_imghash(img_hash, hamming_distance)
return 0

async def get_imghash_full(self, image: "np.ndarray") -> Tuple[int, str]:
Expand Down