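"""Generate text-area masks for rendered Canva CDF documents.

For each MongoDB record that has a ``rendered_folder``, this script downloads
the paired renders ``t=true.png`` and ``t=false.png`` from Azure Blob Storage
(assumed identical except for the text areas), diffs them into a binary mask,
filters out unusable masks, uploads the result as ``mask.png``, and writes the
outcome back to the ``canva.cdf`` collection.
"""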
import logging
import asyncio
import time
from collections import Counter

import cv2
import numpy as np
from azure.storage.blob.aio import BlobServiceClient, download_blob_from_url
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorCursor, AsyncIOMotorCollection, AsyncIOMotorDatabase

AZUREBLOB_SAS_URL = "https://internblob.blob.core.windows.net/v-lixinyang/?sp=racwdli&st=2023-09-17T15:37:58Z&se=2023-12-31T23:37:58Z&spr=https&sv=2022-11-02&sr=c&sig=u%2FPbZ4fNttAPeLj0NEEpX0eIgFcjhot%2Bmy3iGd%2BCmxk%3D"
CONTAINER = "canva-render-10.19"
MONGODB_URI = "mongodb://localhost:27017/canva"
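

# Async wrapper around the Azure Blob SDK for reading and writing render
# images in the CONTAINER blob container.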
class BlobAsync(object):

    async def readall(self, blob):
        blob_service_client = BlobServiceClient(AZUREBLOB_SAS_URL)
        async with blob_service_client:
            container_client = blob_service_client.get_container_client(CONTAINER)
            # async for bname in container_client.list_blob_names():
            #     print(bname)
            blob_client = container_client.get_blob_client(blob)

            if not await blob_client.exists():
                return None
            stream = await blob_client.download_blob()

            # readall() is a coroutine on the async downloader, so it must be awaited
            # before the surrounding `async with` closes the client.
            return await stream.readall()
    async def open_image(self, blob: str):
        async with BlobServiceClient(AZUREBLOB_SAS_URL) as blob_service_client:
            container_client = blob_service_client.get_container_client(CONTAINER)
            # async for bname in container_client.list_blob_names():
            #     print(bname)
            blob_client = container_client.get_blob_client(blob)

            if not await blob_client.exists():
                return None
            stream = await blob_client.download_blob()

            # Decode the downloaded bytes into a 3-channel BGR image.
            buf = np.frombuffer(await stream.readall(), dtype=np.uint8)
            image = cv2.imdecode(buf, cv2.IMREAD_COLOR)

            await blob_client.close()
            await container_client.close()

            return image
    async def upload_image(self, blob, image):
        async with BlobServiceClient(AZUREBLOB_SAS_URL) as blob_service_client:
            # Instantiate a new ContainerClient
            container_client = blob_service_client.get_container_client(CONTAINER)
            blob_client = container_client.get_blob_client(blob)
            is_success, buffer = cv2.imencode('.png', image)

            await blob_client.upload_blob(data=buffer.tobytes(), overwrite=True)
            await blob_client.close()
            await container_client.close()
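

# Mask generation: the two renders are assumed to differ only in their text
# layers, so thresholding their absolute difference yields a binary text mask.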
async def get_mask(img1, img2):
    """Assume img1 and img2 are exactly the same, except text areas"""
    try:
        diff = cv2.absdiff(img1, img2)
    except cv2.error as exc:
        # absdiff raises when the two renders do not have matching shapes.
        raise ValueError("img1 and img2 are not the same size") from exc
    # The images come from cv2.imdecode(..., IMREAD_COLOR) and are 3-channel BGR,
    # so BGR2GRAY (not RGBA2GRAY) is the correct conversion here.
    mask = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
    thresh, binmask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
    return thresh, binmask
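

# Mask filters: each filter returns True when the mask should be rejected,
# and process_cdf() records the paired reason string in MongoDB.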
async def filter_mask_size(mask, thresh=0.4):
    non_zero_pixels = cv2.countNonZero(mask)
    total_pixels = mask.shape[0] * mask.shape[1]
    return non_zero_pixels > total_pixels * thresh
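

# Running tally of processing outcomes (success, not found, filtered, ...),
# printed after each batch in main().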
mask_filtered_count = Counter()


async def process_cdf(blob: BlobAsync, collection, cdf):
    folder = cdf["rendered_folder"]
    # Download both renders concurrently; the TaskGroup waits for both on exit.
    async with asyncio.TaskGroup() as g:
        task1 = g.create_task(blob.open_image(f"{folder}/t=true.png"))
        task2 = g.create_task(blob.open_image(f"{folder}/t=false.png"))
    img1, img2 = task1.result(), task2.result()

    if img1 is None and img2 is None:
        mask_filtered_count["not found"] += 1
        await collection.update_one({"_id": cdf["_id"]}, {"$set": {"last_mask_render": -1, "failed_reason": "not found both"}})
        return
    if img1 is None:
        mask_filtered_count["not found"] += 1
        await collection.update_one({"_id": cdf["_id"]}, {"$set": {"last_mask_render": -1, "failed_reason": "not found t=true"}})
        return
    if img2 is None:
        mask_filtered_count["not found"] += 1
        await collection.update_one({"_id": cdf["_id"]}, {"$set": {"last_mask_render": -1, "failed_reason": "not found t=false"}})
        return
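
    # Build the mask; get_mask raises ValueError when the two renders differ in size.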
    try:
        binary_thresh, mask = await get_mask(img1, img2)
    except ValueError:
        await collection.update_one({"_id": cdf["_id"]}, {"$set": {"last_mask_render": -1, "failed_reason": "size not match"}})
        mask_filtered_count["size not match"] += 1
        return

    # Run every registered mask filter; any filter that returns True rejects the mask.
    mask_filters = [
        (filter_mask_size, "mask too small")
    ]
    tasks = list()

    async with asyncio.TaskGroup() as g:
        for f, reason in mask_filters:
            tasks.append((g.create_task(f(mask)), reason))
    for task, reason in tasks:
        if task.result():
            mask_filtered_count[reason] += 1
            await collection.update_one({"_id": cdf["_id"]}, {"$set": {"last_mask_render": -1, "failed_reason": reason}, "$unset": {"last_fetched": -1}})
            return

    await blob.upload_image(f"{folder}/mask.png", mask)
    await collection.update_one({"_id": cdf["_id"]}, {"$set": {"last_mask_render": time.time()}})
    mask_filtered_count["success"] += 1
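

# Batch driver: fetch_cdf_batch() pulls pending CDF records from MongoDB and
# main() fans each batch out to concurrent process_cdf() tasks.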
async def fetch_cdf_batch(collection):
    """Return up to 200 CDF documents that still need a mask render."""
    cdf_cursor: AsyncIOMotorCursor = collection.find({
        '$or': [
            { '$and': [
                { 'rendered_folder': { '$exists': True } },
                { 'last_fetched': { '$gt': 1697688216 } },
                { 'last_fetched': { '$lt': time.time() - 600 } },
                { 'last_mask_render': { '$exists': False } }
            ]},
            { '$and': [
                { 'last_fetched': { '$gt': 1697998932 } },
                { 'last_mask_render': { '$not': { '$gt': 0 } } }
            ]}
        ]}, batch_size=400)
    cdf_list = await cdf_cursor.to_list(length=200)
    await cdf_cursor.close()
    return cdf_list


async def main():
    client = AsyncIOMotorClient(MONGODB_URI)
    db = client.get_database("canva")
    collection = db["cdf"]

    # Quiet the verbose Azure SDK logger.
    logger = logging.getLogger('azure.mgmt.resource')
    logger.setLevel(logging.WARNING)
    blob = BlobAsync()

    cdf_list = await fetch_cdf_batch(collection)
    # Keep processing batches until a fetch returns no documents.
    while cdf_list:
        async with asyncio.TaskGroup() as g:
            taskset = set()
            for cdf in cdf_list:
                taskset.add(
                    g.create_task(process_cdf(blob, collection, cdf))
                )
        await asyncio.sleep(10)
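        # Fetch the next batch for the following loop iteration.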
        cdf_list = await fetch_cdf_batch(collection)
        print(mask_filtered_count)


if __name__ == "__main__":
    asyncio.run(main())