# canva-render/scripts/post_processing.py
# 2023-12-19 14:39:13 +08:00
#
# 174 lines
# 7 KiB
# Python

import logging
import asyncio
import time
from collections import Counter
import cv2
import numpy as np
from azure.storage.blob.aio import BlobServiceClient, download_blob_from_url
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorCursor, AsyncIOMotorCollection, AsyncIOMotorDatabase
# SAS URL handed to `BlobServiceClient` by every `BlobAsync` method.
# SECURITY NOTE(review): this literal embeds a signed token (the `sig=` query
# parameter), so anyone who can read this file can use it until the `se=`
# expiry shown in the URL. Consider loading it from the environment or a secret
# store instead of committing it, and rotating the token.
AZUREBLOB_SAS_URL = "https://internblob.blob.core.windows.net/v-lixinyang/?sp=racwdli&st=2023-09-17T15:37:58Z&se=2023-12-31T23:37:58Z&spr=https&sv=2022-11-02&sr=c&sig=u%2FPbZ4fNttAPeLj0NEEpX0eIgFcjhot%2Bmy3iGd%2BCmxk%3D"
# Container name passed to `get_container_client` for every blob operation.
CONTAINER = "canva-render-10.19"
# Connection URI passed to `AsyncIOMotorClient` in `main`.
MONGODB_URI = "mongodb://localhost:27017/canva"
class BlobAsync:
    """Stateless async helper for blobs in the ``CONTAINER`` container.

    Every call opens its own ``BlobServiceClient`` from ``AZUREBLOB_SAS_URL``
    and closes it before returning; the instance itself holds no state.
    """

    async def _download(self, blob: str) -> bytes | None:
        """Return the full content of ``blob``, or ``None`` if it does not exist."""
        async with BlobServiceClient(AZUREBLOB_SAS_URL) as service_client:
            container_client = service_client.get_container_client(CONTAINER)
            blob_client = container_client.get_blob_client(blob)
            try:
                if not await blob_client.exists():
                    return None
                stream = await blob_client.download_blob()
                # ``readall`` is a coroutine in the aio SDK: it must be awaited
                # here, while the service client is still open.
                return await stream.readall()
            finally:
                # Close the child clients on every path, including the early
                # "blob missing" return and errors raised by the SDK.
                await blob_client.close()
                await container_client.close()

    async def readall(self, blob):
        """Download ``blob`` and return its raw bytes.

        Args:
            blob: name (path) of the blob inside ``CONTAINER``.

        Returns:
            The blob content as ``bytes``, or ``None`` if the blob does not
            exist.
        """
        return await self._download(blob)

    async def open_image(self, blob: str):
        """Download ``blob`` and decode it as a colour image.

        Args:
            blob: name (path) of the blob inside ``CONTAINER``.

        Returns:
            The array produced by ``cv2.imdecode(..., cv2.IMREAD_COLOR)``, or
            ``None`` if the blob does not exist or is empty. The result is also
            ``None`` when OpenCV cannot decode the data.
        """
        data = await self._download(blob)
        if not data:
            # Missing blob, or a zero-byte blob that OpenCV would reject.
            return None
        buf = np.frombuffer(data, dtype=np.uint8)
        return cv2.imdecode(buf, cv2.IMREAD_COLOR)

    async def upload_image(self, blob, image):
        """Encode ``image`` as PNG and upload it to ``blob``, overwriting it.

        Args:
            blob: destination blob name (path) inside ``CONTAINER``.
            image: image array accepted by ``cv2.imencode``.

        Raises:
            ValueError: if OpenCV reports that PNG encoding failed.
        """
        # Encode first so that a bad image never opens a network connection.
        is_success, buffer = cv2.imencode('.png', image)
        if not is_success:
            raise ValueError(f"cv2.imencode failed to encode image for blob {blob!r}")
        async with BlobServiceClient(AZUREBLOB_SAS_URL) as service_client:
            container_client = service_client.get_container_client(CONTAINER)
            blob_client = container_client.get_blob_client(blob)
            try:
                await blob_client.upload_blob(data=buffer.tobytes(), overwrite=True)
            finally:
                await blob_client.close()
                await container_client.close()
async def get_mask(img1, img2):
    """Build a binary mask of the pixels where two images differ.

    The two images are assumed to be exactly the same except in text areas.

    Args:
        img1: first image array.
        img2: second image array; must have the same shape as ``img1``.

    Returns:
        The ``(thresh, binmask)`` pair returned by ``cv2.threshold``: the
        threshold value that was applied, and the mask in which pixels whose
        grayscale difference is above 10 are set to 255 and all others to 0.

    Raises:
        ValueError: if the images differ in shape, or OpenCV cannot compute
            their difference.
    """
    # Check the shapes explicitly instead of relying on OpenCV to fail.
    if np.shape(img1) != np.shape(img2):
        raise ValueError("img1 and img2 are not the same size")
    try:
        diff = cv2.absdiff(img1, img2)
    except cv2.error as err:
        # Catch only OpenCV failures and keep the original cause attached.
        raise ValueError("img1 and img2 are not the same size") from err
    # NOTE(review): the images come from open_image(), which decodes with
    # cv2.IMREAD_COLOR; that is presumably 3-channel BGR, so COLOR_BGR2GRAY may
    # be the intended conversion code here - confirm before changing, since it
    # would alter the generated masks.
    mask = cv2.cvtColor(diff, cv2.COLOR_RGBA2GRAY)
    thresh, binmask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
    return thresh, binmask
async def filter_mask_size(mask, thresh=0.4):
    """Tell whether the non-zero area of ``mask`` exceeds a share of the image.

    Args:
        mask: array passed to ``cv2.countNonZero``; its first two dimensions
            are used as the pixel grid.
        thresh: fraction of the total pixel count used as the limit.

    Returns:
        ``True`` when the number of non-zero pixels is strictly greater than
        ``thresh`` times the total pixel count, ``False`` otherwise.
    """
    lit_pixels = cv2.countNonZero(mask)
    rows, cols = mask.shape[0], mask.shape[1]
    return lit_pixels > rows * cols * thresh
# Module-level tally of processing outcomes, keyed by "success" or by a failure
# reason; incremented by `process_cdf` and printed by `main`.
mask_filtered_count = Counter()
async def process_cdf(blob: BlobAsync, collection, cdf):
    """Build and upload the mask for one document, recording the outcome.

    Downloads ``t=true.png`` and ``t=false.png`` from the document's
    ``rendered_folder``, derives a mask from them, runs the mask filters and
    uploads ``mask.png`` when everything passes. The result is written back to
    the document (``last_mask_render`` set to the current time on success, or
    to ``-1`` with a ``failed_reason`` on failure) and tallied in
    ``mask_filtered_count``.

    Args:
        blob: storage helper used to download and upload images.
        collection: collection whose ``update_one`` receives the outcome.
        cdf: document with at least ``_id`` and ``rendered_folder``.
    """
    folder = cdf["rendered_folder"]

    def failure(reason):
        # Common "$set" payload written for every failed document.
        return {"$set": {"last_mask_render": -1, "failed_reason": reason}}

    # Fetch both renders concurrently.
    async with asyncio.TaskGroup() as downloads:
        true_task = downloads.create_task(blob.open_image(f"{folder}/t=true.png"))
        false_task = downloads.create_task(blob.open_image(f"{folder}/t=false.png"))
    img_true, img_false = true_task.result(), false_task.result()

    # Keyed by (t=true missing, t=false missing).
    not_found_reasons = {
        (True, True): "not found both",
        (True, False): "not found t=true",
        (False, True): "not found t=false",
    }
    missing_reason = not_found_reasons.get((img_true is None, img_false is None))
    if missing_reason is not None:
        mask_filtered_count["not found"] += 1
        await collection.update_one({"_id": cdf["_id"]}, failure(missing_reason))
        return

    try:
        _, mask = await get_mask(img_true, img_false)
    except ValueError:
        await collection.update_one({"_id": cdf["_id"]}, failure("size not match"))
        mask_filtered_count["size not match"] += 1
        return

    # Each entry is (async predicate, reason recorded when it returns True).
    # NOTE(review): filter_mask_size returns True when the non-zero area is
    # ABOVE its threshold, so the "mask too small" label looks inverted -
    # confirm before renaming, because the string is persisted.
    mask_filters = ((filter_mask_size, "mask too small"),)
    async with asyncio.TaskGroup() as checks:
        verdicts = [
            (checks.create_task(check(mask)), reason)
            for check, reason in mask_filters
        ]
    for verdict, reason in verdicts:
        if not verdict.result():
            continue
        mask_filtered_count[reason] += 1
        rejection = failure(reason)
        # A rejected mask also drops "last_fetched" from the document.
        rejection["$unset"] = {"last_fetched": -1}
        await collection.update_one({"_id": cdf["_id"]}, rejection)
        return

    await blob.upload_image(f"{folder}/mask.png", mask)
    await collection.update_one(
        {"_id": cdf["_id"]},
        {"$set": {"last_mask_render": time.time()}},
    )
    mask_filtered_count["success"] += 1
def _pending_cdf_query():
    """Build the MongoDB filter for documents that still need a mask."""
    return {
        '$or': [
            {'$and': [
                {'rendered_folder': {'$exists': True}},
                {'last_fetched': {'$gt': 1697688216}},
                # Evaluated on every call, so the 600-second window moves
                # forward with each poll.
                {'last_fetched': {'$lt': time.time() - 600}},
                {'last_mask_render': {'$exists': False}},
            ]},
            {'$and': [
                {'last_fetched': {'$gt': 1697998932}},
                {'last_mask_render': {'$not': {'$gt': 0}}},
            ]},
        ],
    }


async def _fetch_pending(collection):
    """Return up to 200 pending documents, always closing the cursor."""
    cursor: AsyncIOMotorCursor = collection.find(_pending_cdf_query(), batch_size=400)
    try:
        return await cursor.to_list(length=200)
    finally:
        await cursor.close()


async def main():
    """Process pending documents in batches until none are left.

    Each round fetches up to 200 documents from the ``cdf`` collection of the
    ``canva`` database, runs ``process_cdf`` on all of them concurrently,
    sleeps 10 seconds, fetches the next batch and prints the running tally.
    """
    client = AsyncIOMotorClient(MONGODB_URI)
    try:
        collection = client.get_database("canva")["cdf"]
        logging.getLogger('azure.mgmt.resource').setLevel(logging.WARNING)
        blob = BlobAsync()

        cdf_list = await _fetch_pending(collection)
        # The previous `cdf_list is not []` compared identity against a fresh
        # list and was therefore always true; test for emptiness instead.
        while cdf_list:
            # The TaskGroup keeps its own references to the tasks and waits for
            # all of them before the block exits.
            async with asyncio.TaskGroup() as group:
                for cdf in cdf_list:
                    group.create_task(process_cdf(blob, collection, cdf))
            await asyncio.sleep(10)
            cdf_list = await _fetch_pending(collection)
            print(mask_filtered_count)
    finally:
        client.close()


if __name__ == "__main__":
    asyncio.run(main())