# canva-render/render_controller.py
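"""Scrapy-based controller that renders Canva CDF templates layer by layer.

Overview, as implemented below: every CDF is POSTed to a local render service
at http://localhost:8000/render, once with all elements present and once per
progressively removed element; consecutive renders are diffed into
per-element text masks, and images plus masks are uploaded to Azure Blob
Storage.
"""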
import argparse
import json
import logging
import queue
import re
from collections import defaultdict
from copy import deepcopy
from dataclasses import asdict, dataclass
from io import BytesIO
from typing import Optional

import cv2
import numpy as np
import scrapy
from azure.storage.blob import BlobServiceClient
from scrapy.crawler import CrawlerProcess
from scrapy.exceptions import DropItem
from scrapy.logformatter import LogFormatter

from cdf_parser import CdfParser


def get_mask(img1: np.ndarray, img2: np.ndarray):
    """Assume img1 and img2 are exactly the same, except in text areas."""
    try:
        diff = cv2.absdiff(img1, img2)
    except cv2.error as e:
        raise ValueError("img1 and img2 are not the same size") from e
    # OpenCV decodes PNGs as BGR(A), so pick the grayscale conversion to match.
    code = cv2.COLOR_BGRA2GRAY if diff.ndim == 3 and diff.shape[2] == 4 else cv2.COLOR_BGR2GRAY
    gray = cv2.cvtColor(diff, code)
    thresh, binmask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
    return thresh, binmask
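
# Sketch of the expected behaviour (hypothetical arrays, not pipeline data):
#   a = np.zeros((4, 4, 3), np.uint8)
#   b = a.copy(); b[1:3, 1:3] = 255   # pretend a text patch differs
#   _, m = get_mask(a, b)             # m is 255 exactly on that patch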


@dataclass
class RenderedImage:
    _id: str
    image: bytes
    cdf_with_metadata: dict
    total_elements: int
    element_idx: Optional[int]  # None for the full render, else the layer index


@dataclass
class RenderedImageWithMask(RenderedImage):
    mask: bytes


class QuietLogFormatter(LogFormatter):
    """Log dropped items at INFO instead of WARNING to keep the logs quiet."""

    def dropped(self, item, exception, response, spider):
        return {
            'level': logging.INFO,  # lowered from logging.WARNING
            'msg': "Dropped: %(exception)s",
            'args': {
                'exception': exception,
            },
        }


class GroupAndFilterPipeline:
    # Rendered layers per template id, keyed by layer index (-1 = full render).
    unpaired: dict[str, dict[int, RenderedImage]] = defaultdict(dict)

    @staticmethod
    def mask_discriminator(img1, img2, mask, thresh=0.5):
        # Accept the mask only if the diff covers at most `thresh` of the canvas.
        non_zero_pixels = cv2.countNonZero(mask)
        total_pixels = mask.shape[0] * mask.shape[1]
        return non_zero_pixels <= total_pixels * thresh
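
    # Rationale, as implied by the "Render error" path below: a diff covering
    # more than half of the canvas points at a failed render rather than a
    # removed text element, so the CDF is requeued instead of keeping the mask.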

    def process_item(self, item: RenderedImage, spider):
        # Store the whole-image render at index -1 to keep the layers ordered.
        idx = -1 if item.element_idx is None else item.element_idx
        self.unpaired[item._id][idx] = item
        rendered_items = self.unpaired[item._id]
        if len(rendered_items) - 1 != item.total_elements:
            # Still waiting for more layers of this template.
            return None
        # All layers have arrived; drop the group from the buffer, then validate.
        del self.unpaired[item._id]
        for i in range(-1, item.total_elements):
            if i not in rendered_items:
                # Missing layer, put the CDF back on the queue for a retry.
                spider.cdf_queue.put(item.cdf_with_metadata)
                raise DropItem(f"Missing layer {i} for {item._id}, layer index must cover [-1, {item.total_elements})")
        diff_masks = dict()
        for i in range(item.total_elements):
            img1 = cv2.imdecode(np.asarray(bytearray(rendered_items[i - 1].image), dtype=np.uint8), cv2.IMREAD_ANYCOLOR)
            img2 = cv2.imdecode(np.asarray(bytearray(rendered_items[i].image), dtype=np.uint8), cv2.IMREAD_ANYCOLOR)
            try:
                _, mask = get_mask(img1, img2)
            except ValueError:
                # Size mismatch in get_mask, put the CDF back on the queue.
                spider.cdf_queue.put(item.cdf_with_metadata)
                raise DropItem(f"Error in get_mask for {item._id} between {i} and {i - 1}")
            if self.mask_discriminator(img1, img2, mask):
                diff_masks[i] = mask
            else:
                # Render error, put the CDF back on the queue.
                spider.cdf_queue.put(item.cdf_with_metadata)
                raise DropItem(f"Discriminator failed for {item._id} between {i} and {i - 1}")
        return [rendered_items[-1]] + [
            RenderedImageWithMask(
                mask=cv2.imencode('.png', mask)[1].tobytes(),
                **asdict(rendered_items[i]),
            )
            for i, mask in diff_masks.items()
        ]


class AzureUploadPipeline:
    AZUREBLOB_SAS_URL = "https://internblob.blob.core.windows.net/v-lixinyang?sp=racwdli&st=2023-10-08T06:40:24Z&se=2024-01-04T14:40:24Z&spr=https&sv=2022-11-02&sr=c&sig=77O5xpepaRehrwUJ0FSnYQDbTBJzxg629GpStnWrFg4%3D"
    CONTAINER = "canva-render-11.30"
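
    # Blob naming, as implemented in process_item below: "<id>-full.png" for
    # the whole render; "<id>-(i).png" plus "<id>-(i)-mask.png" for each
    # element layer and its text mask.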

    def process_item(self, items: Optional[list[RenderedImageWithMask]], spider):
        if items is None:
            return None
        for item in items:
            if item.element_idx is None:
                filename = f"{item._id}-full.png"
                self.client.upload_blob(name=filename, data=BytesIO(item.image), overwrite=True)
            else:
                filename = f"{item._id}-({item.element_idx}).png"
                filename_mask = f"{item._id}-({item.element_idx})-mask.png"
                self.client.upload_blob(name=filename, data=BytesIO(item.image), overwrite=True)
                self.client.upload_blob(name=filename_mask, data=BytesIO(item.mask), overwrite=True)
        spider.logger.info(f"Uploaded {items[0]._id} to Azure Blob Storage")
        return items

    def open_spider(self, spider):
        # Quiet the Azure SDK's request/response logging.
        logging.getLogger('azure').setLevel(logging.WARNING)
        blob_client = BlobServiceClient(account_url=self.AZUREBLOB_SAS_URL)
        self.client = blob_client.get_container_client(self.CONTAINER)


class ImagesCrawler(scrapy.Spider):
    name = "render-controller"
    custom_settings = {
        "LOG_LEVEL": "INFO",
        "LOG_FORMATTER": "__main__.QuietLogFormatter",
        "ITEM_PIPELINES": {
            "__main__.GroupAndFilterPipeline": 300,
            "__main__.AzureUploadPipeline": 400,
        },
        # "DOWNLOAD_DELAY": 1,
        "CONCURRENT_REQUESTS": 8,
    }
    cdf_file = "cdfs.json"
    cdf_queue = queue.Queue()
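
    # cdf_queue doubles as a retry channel: GroupAndFilterPipeline puts a CDF
    # back on it whenever a layer is missing or a diff looks wrong, and
    # start_requests keeps draining it until it has been empty for 10 seconds.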

    def start_requests(self):
        self.logger.info(f"Reading cdf file {self.cdf_file}")
        with open(self.cdf_file, "r") as f:
            for line in f:
                cdf_with_metadata = json.loads(line)
                self.cdf_queue.put(cdf_with_metadata)
        while True:
            try:
                cdf_with_metadata = self.cdf_queue.get(timeout=10)
            except queue.Empty:
                # Queue stayed empty for 10s: assume all CDFs are done.
                break
            match = re.match(r"^https://template\.canva\.com/[-_\w]{1,11}/(\d+)/(\d+)/(.*)\.cdf$", cdf_with_metadata["url"])
            assert match, f"Unexpected CDF url: {cdf_with_metadata['url']}"
            num1, num2, tid = match.groups()
            cdf_id = f"{cdf_with_metadata['template_id']}-{num1}-{num2}-{tid}"
            cdf_parser = CdfParser(cdf_with_metadata["content"], cdf_with_metadata["template_id"])
            total_elements = len(cdf_parser.get_elements())
            # Render the whole image with all elements present.
            yield scrapy.Request(
                "http://localhost:8000/render",
                method='POST',
                body=json.dumps(cdf_with_metadata["content"]),
                callback=self.parse,
                cb_kwargs=dict(
                    id=cdf_id,
                    cdf_with_metadata=cdf_with_metadata,
                    total_elements=total_elements,
                    element_idx=None,
                ),
                headers={'Content-Type': 'application/json'},
            )
            # Render once more per element, removing elements one at a time.
            layered_cdf_with_metadata = deepcopy(cdf_with_metadata)
            layered_cdf_parser = CdfParser(layered_cdf_with_metadata["content"], layered_cdf_with_metadata["template_id"])
            for element_idx, layer_cdf in enumerate(layered_cdf_parser.remove_elements_iter()):
                yield scrapy.Request(
                    "http://localhost:8000/render",
                    method='POST',
                    body=json.dumps(layer_cdf),
                    callback=self.parse,
                    cb_kwargs=dict(
                        id=cdf_id,
                        cdf_with_metadata=cdf_with_metadata,
                        total_elements=total_elements,
                        element_idx=element_idx,
                    ),
                    headers={'Content-Type': 'application/json'},
                )

    def parse(self, response, id, cdf_with_metadata, total_elements, element_idx):
        yield RenderedImage(
            _id=id,
            cdf_with_metadata=cdf_with_metadata,
            image=response.body,
            total_elements=total_elements,
            element_idx=element_idx,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--cdf-file", default=ImagesCrawler.cdf_file)
    args = parser.parse_args()
    process = CrawlerProcess()
    process.crawl(ImagesCrawler, cdf_file=args.cdf_file)
    process.start()
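
# Example usage (assuming the render service is listening on localhost:8000
# and cdfs.json holds one JSON object per line with "url", "template_id" and
# "content" keys, as read in start_requests):
#
#   python render_controller.py -f cdfs.json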