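"""Render-controller spider.

Reads CDF records from a JSON-lines file, posts each record (and each of its
per-element "layer" variants) to a local render service, pairs the rendered
layers to compute per-element difference masks, and uploads the images and
masks to Azure Blob Storage.
"""
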
import argparse
import json
import logging
import queue
import re
from collections import defaultdict
from copy import deepcopy
from dataclasses import asdict, dataclass
from io import BytesIO
from typing import Optional

import cv2
import numpy as np
import scrapy
from azure.storage.blob import BlobServiceClient
from scrapy.crawler import CrawlerProcess
from scrapy.exceptions import DropItem
from scrapy.logformatter import LogFormatter

from cdf_parser import CdfParser


def get_mask(img1: np.ndarray, img2: np.ndarray):
    """Assume img1 and img2 are exactly the same, except for the text areas."""
    try:
        diff = cv2.absdiff(img1, img2)
    except cv2.error as e:
        raise ValueError("img1 and img2 are not the same size") from e
    mask = cv2.cvtColor(diff, cv2.COLOR_RGBA2GRAY)
    thresh, binmask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
    return thresh, binmask


@dataclass
class RenderedImage:
    _id: str
    image: bytes
    cdf_with_metadata: dict
    total_elements: int
    element_idx: Optional[int]


@dataclass
class RenderedImageWithMask(RenderedImage):
    mask: bytes


class QuietLogFormatter(LogFormatter):
    def dropped(self, item, exception, response, spider):
        return {
            'level': logging.INFO,  # lowered from logging.WARNING so dropped items do not flood the log
            'msg': "Dropped: %(exception)s",
            'args': {
                'exception': exception,
            },
        }


class GroupAndFilterPipeline:
    # Rendered layers waiting for their siblings, keyed by cdf id and then by layer index.
    unpaired: dict[str, dict[int, RenderedImage]] = defaultdict(dict)

    @staticmethod
    def mask_discriminator(img1, img2, mask, thresh=0.5):
        # Reject the pair when more than `thresh` of the pixels differ,
        # which usually indicates a render error rather than a removed element.
        non_zero_pixels = cv2.countNonZero(mask)
        total_pixels = mask.shape[0] * mask.shape[1]
        if non_zero_pixels > total_pixels * thresh:
            return False
        else:
            return True

    def process_item(self, item: RenderedImage, spider):
        if item.element_idx is None:
            # Store the whole image under index -1 to keep the ordering.
            self.unpaired[item._id][-1] = item
        else:
            self.unpaired[item._id][item.element_idx] = item

        rendered_items = self.unpaired[item._id]

        if len(rendered_items) - 1 == item.total_elements:
            # All layers have arrived.
            for i in range(-1, item.total_elements):
                if i not in rendered_items:
                    # Missing layer, put the cdf back on the queue.
                    spider.cdf_queue.put(item.cdf_with_metadata)
                    raise DropItem(f"Missing layer {i} for {item._id}, layer index must cover [-1, {item.total_elements})")

            diff_masks = dict()
            for i in range(item.total_elements):
                img1 = cv2.imdecode(np.asarray(bytearray(rendered_items[i - 1].image), dtype=np.uint8), cv2.IMREAD_ANYCOLOR)
                img2 = cv2.imdecode(np.asarray(bytearray(rendered_items[i].image), dtype=np.uint8), cv2.IMREAD_ANYCOLOR)
                try:
                    _, mask = get_mask(img1, img2)
                except ValueError:
                    # Error in get_mask, put the cdf back on the queue.
                    spider.cdf_queue.put(item.cdf_with_metadata)
                    raise DropItem(f"Error in get_mask for {item._id} between {i} and {i - 1}")
                if self.mask_discriminator(img1, img2, mask):
                    diff_masks[i] = mask
                else:
                    # Render error, put the cdf back on the queue.
                    spider.cdf_queue.put(item.cdf_with_metadata)
                    raise DropItem(f"Discriminator failed for {item._id} between {i} and {i - 1}")

            return [rendered_items[-1]] + [
                RenderedImageWithMask(
                    mask=cv2.imencode('.png', mask)[1].tobytes(),
                    **asdict(rendered_items[i]),
                )
                for i, mask in diff_masks.items()
            ]
        else:
            # Still waiting for the remaining layers.
            return None


class AzureUploadPipeline:

    AZUREBLOB_SAS_URL = "https://internblob.blob.core.windows.net/v-lixinyang?sp=racwdli&st=2023-10-08T06:40:24Z&se=2024-01-04T14:40:24Z&spr=https&sv=2022-11-02&sr=c&sig=77O5xpepaRehrwUJ0FSnYQDbTBJzxg629GpStnWrFg4%3D"
    CONTAINER = "canva-render-11.30"

    def process_item(self, items: Optional[list[RenderedImage]], spider):
        if items is None:
            return None

        for item in items:
            if item.element_idx is None:
                filename = f"{item._id}-full.png"
                self.client.upload_blob(name=filename, data=BytesIO(item.image), overwrite=True)
            else:
                filename = f"{item._id}-({item.element_idx}).png"
                filename_mask = f"{item._id}-({item.element_idx})-mask.png"
                self.client.upload_blob(name=filename, data=BytesIO(item.image), overwrite=True)
                self.client.upload_blob(name=filename_mask, data=BytesIO(item.mask), overwrite=True)
        spider.logger.info(f"Uploaded {items[0]._id} to Azure Blob Storage")
        return items

    def open_spider(self, spider):
        # Quiet the Azure SDK logger so upload requests are not logged at INFO level.
        logger = logging.getLogger('azure')
        logger.setLevel(logging.WARNING)

        blob_client = BlobServiceClient(account_url=self.AZUREBLOB_SAS_URL, logger=logger)
        self.client = blob_client.get_container_client(self.CONTAINER)


class ImagesCrawler(scrapy.Spider):
    name = "render-controller"
    custom_settings = {
        "LOG_LEVEL": "INFO",
        "LOG_FORMATTER": "__main__.QuietLogFormatter",
        "ITEM_PIPELINES": {
            "__main__.GroupAndFilterPipeline": 300,
            "__main__.AzureUploadPipeline": 400,
        },
        # "DOWNLOAD_DELAY": 1,
        "CONCURRENT_REQUESTS": 8,
    }
    cdf_file = "cdfs.json"
    cdf_queue = queue.Queue()

    def start_requests(self):
        self.logger.info(f"Reading cdf file {self.cdf_file}")
        with open(self.cdf_file, "r") as f:
            for line in f:
                cdf_with_metadata = json.loads(line)
                self.cdf_queue.put(cdf_with_metadata)

        while True:
            try:
                cdf_with_metadata = self.cdf_queue.get(timeout=10)
            except queue.Empty:
                # No new work for 10 seconds; assume all cdfs have been rendered.
                break
            assert (match := re.match(r"^https://template\.canva\.com/[-_\w]{1,11}/(\d+)/(\d+)/(.*)\.cdf$", cdf_with_metadata["url"]))
            num1, num2, tid = match.groups()

            cdf_id = f"{cdf_with_metadata['template_id']}-{num1}-{num2}-{tid}"

            cdf_parser = CdfParser(cdf_with_metadata["content"], cdf_with_metadata["template_id"])
            total_elements = len(cdf_parser.get_elements())

            # Render the whole image with all elements present.
            yield scrapy.Request(
                "http://localhost:8000/render",
                method='POST',
                body=json.dumps(cdf_with_metadata["content"]),
                callback=self.parse,
                cb_kwargs=dict(
                    id=cdf_id,
                    cdf_with_metadata=cdf_with_metadata,
                    total_elements=total_elements,
                    element_idx=None,
                ),
                headers={'Content-Type': 'application/json'},
            )

            # Render one image per layer, removing elements one at a time.
            layered_cdf_with_metadata = deepcopy(cdf_with_metadata)
            layered_cdf_parser = CdfParser(layered_cdf_with_metadata["content"], layered_cdf_with_metadata["template_id"])

            for element_idx, layer_cdf in enumerate(layered_cdf_parser.remove_elements_iter()):
                yield scrapy.Request(
                    "http://localhost:8000/render",
                    method='POST',
                    body=json.dumps(layer_cdf),
                    callback=self.parse,
                    cb_kwargs=dict(
                        id=cdf_id,
                        cdf_with_metadata=cdf_with_metadata,
                        total_elements=total_elements,
                        element_idx=element_idx,
                    ),
                    headers={'Content-Type': 'application/json'},
                )

    def parse(self, response, id, cdf_with_metadata, total_elements, element_idx):
        yield RenderedImage(
            _id=id,
            cdf_with_metadata=cdf_with_metadata,
            image=response.body,
            total_elements=total_elements,
            element_idx=element_idx,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--cdf-file")
    args = parser.parse_args()

    process = CrawlerProcess()
    process.crawl(ImagesCrawler, cdf_file=args.cdf_file)
    process.start()
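
# Example invocation (the script name is hypothetical; assumes a render service
# listening on http://localhost:8000/render and a JSON-lines cdf file):
#
#   python render_controller.py --cdf-file cdfs.json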