From b3e7107168134b820f940004bae4a0da13fb2b83 Mon Sep 17 00:00:00 2001 From: StevenLiuWen Date: Fri, 8 Mar 2024 15:37:17 +0800 Subject: [PATCH] version 1.0.0 --- README.md | 103 ++-- cli_chat.py | 194 +++++++ deepseek_vlm/__init__.py | 0 deepseek_vlm/models/__init__.py | 5 + deepseek_vlm/models/clip_encoder.py | 213 +++++++ deepseek_vlm/models/image_processing_vlm.py | 163 ++++++ deepseek_vlm/models/modeling_vlm.py | 150 +++++ deepseek_vlm/models/processing_vlm.py | 351 ++++++++++++ deepseek_vlm/models/projector.py | 80 +++ deepseek_vlm/models/sam.py | 562 ++++++++++++++++++ deepseek_vlm/models/siglip_vit.py | 605 ++++++++++++++++++++ deepseek_vlm/utils/__init__.py | 0 deepseek_vlm/utils/conversation.py | 326 +++++++++++ deepseek_vlm/utils/io.py | 55 ++ images/latex_01.jpg | Bin 0 -> 50179 bytes images/monday.jpg | Bin 0 -> 48983 bytes images/training_pipelines.png | Bin 0 -> 103185 bytes inference.py | 53 ++ requirements.txt | 11 +- setup.py | 17 + 20 files changed, 2842 insertions(+), 46 deletions(-) create mode 100644 cli_chat.py create mode 100644 deepseek_vlm/__init__.py create mode 100644 deepseek_vlm/models/__init__.py create mode 100644 deepseek_vlm/models/clip_encoder.py create mode 100644 deepseek_vlm/models/image_processing_vlm.py create mode 100644 deepseek_vlm/models/modeling_vlm.py create mode 100644 deepseek_vlm/models/processing_vlm.py create mode 100644 deepseek_vlm/models/projector.py create mode 100644 deepseek_vlm/models/sam.py create mode 100644 deepseek_vlm/models/siglip_vit.py create mode 100644 deepseek_vlm/utils/__init__.py create mode 100644 deepseek_vlm/utils/conversation.py create mode 100644 deepseek_vlm/utils/io.py create mode 100644 images/latex_01.jpg create mode 100644 images/monday.jpg create mode 100644 images/training_pipelines.png create mode 100644 inference.py create mode 100644 setup.py diff --git a/README.md b/README.md index dfd1a95..2140680 100644 --- a/README.md +++ b/README.md @@ -72,16 +72,21 @@ Introducing DeepSeek VL, ## 2. Model Downloads -We release the DeepSeek LLM 7B/67B, including both base and chat models, to the public. To support a broader and more diverse range of research within both academic and commercial communities, we are providing access to the intermediate checkpoints of the base model from its training process. Please **note** that the use of this model is subject to the terms outlined in [License section](#8-license). Commercial usage is permitted under these terms. +We release the DeepSeek-VL family, including 1.3B-base, 1.3B-chat, 7b-base and 7b-chat models, to the public. +To support a broader and more diverse range of research within both academic and commercial communities. +Please note that the use of this model is subject to the terms outlined in [License section](#8-license). Commercial usage is +permitted under these terms. ### Huggingface -| Model | Sequence Length | Download | -|:---------------------:|:---------------:|:-----------------------------------------------------------------------:| -| DeepSeek VL 7B Base | 4096 | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-llm-7b-base) | -| DeepSeek VL 7B Chat | 4096 | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-llm-7b-chat) | +| Model | Sequence Length | Download | +|-----------------------|-----------------|-----------------------------------------------------------------------------| +| DeepSeek-VL-1.3B-base | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-1.3b-base) | +| DeepSeek-VL-1.3B-chat | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-1.3b-chat) | +| DeepSeek-VL-7B-base | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-7b-base) | +| DeepSeek-VL-7B-chat | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat) | -## 3. Evaluation Results +# 3. Evaluation Results ### Base Model @@ -139,53 +144,73 @@ We release the training loss curve and several benchmark metrics curves, as deta On the basis of `Python >= 3.8` environment, install the necessary dependencies by running the following command: ```shell -pip install -r requirements.txt +pip install -r requirements.txt -e . ``` ### Inference with Huggingface's Transformers You can directly employ [Huggingface's Transformers](https://github.com/huggingface/transformers) for model inference. -**Text Completion** +**Simple Inference Example** ```python import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig +from transformers import AutoModelForCausalLM -model_name = "deepseek-ai/deepseek-llm-67b-base" -tokenizer = AutoTokenizer.from_pretrained(model_name) -model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto") -model.generation_config = GenerationConfig.from_pretrained(model_name) -model.generation_config.pad_token_id = model.generation_config.eos_token_id +from deepseek_vlm.models import VLChatProcessor, MultiModalityCausalLM +from deepseek_vlm.utils.io import load_pil_images -text = "An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is" -inputs = tokenizer(text, return_tensors="pt") -outputs = model.generate(**inputs.to(model.device), max_new_tokens=100) -result = tokenizer.decode(outputs[0], skip_special_tokens=True) -print(result) +# specify the path to the model +model_path = "deepseek-ai/deepseek-vl-7b-chat" +vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) +tokenizer = vl_chat_processor.tokenizer + +vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) +vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval() + +conversation = [ + { + "role": "User", + "content": "Describe each stage of this image.", + "images": ["./images/training_pipelines.png"] + }, + { + "role": "Assistant", + "content": "" + } +] + +# load images and prepare for inputs +pil_images = load_pil_images(conversation) +prepare_inputs = vl_chat_processor( + conversations=conversation, + images=pil_images, + force_batchify=True +).to(vl_gpt.device) + +# run image encoder to get the image embeddings +inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) + +# run the model to get the response +outputs = vl_gpt.language_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=prepare_inputs.attention_mask, + pad_token_id=tokenizer.eos_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=512, + do_sample=False, + use_cache=True +) + +answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True) +print(f"{prepare_inputs['sft_format'][0]}", answer) ``` -**Chat Completion** - -```python -import torch -from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig - -model_name = "deepseek-ai/deepseek-llm-67b-chat" -tokenizer = AutoTokenizer.from_pretrained(model_name) -model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto") -model.generation_config = GenerationConfig.from_pretrained(model_name) -model.generation_config.pad_token_id = model.generation_config.eos_token_id - -messages = [ - {"role": "user", "content": "Who are you?"} -] -input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt") -outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100) - -result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True) -print(result) +**CLI Chat** +```bash +python cli_chat.py --model_path deepseek-ai/deepseek-vl-7b-chat ``` Avoiding the use of the provided function `apply_chat_template`, you can also interact with our model following the sample template. Note that `messages` should be replaced by your input. diff --git a/cli_chat.py b/cli_chat.py new file mode 100644 index 0000000..f2bbfee --- /dev/null +++ b/cli_chat.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- + +import argparse +import os +import sys +from PIL import Image +import readline +from threading import Thread +import torch +from transformers import TextIteratorStreamer +from deepseek_vlm.utils.io import load_pretrained_model + + +def load_image(image_file): + image = Image.open(image_file).convert("RGB") + return image + + +def get_help_message(image_token): + help_msg = ( + f"\t\t DeepSeek-VL-Chat is a chatbot that can answer questions based on the given image. Enjoy it! \n" + f"Usage: \n" + f" 1. type `exit` to quit. \n" + f" 2. type `{image_token}` to indicate there is an image. You can enter multiple images, " + f"e.g '{image_token} is a dot, {image_token} is a cat, and what is it in {image_token}?'. " + f"When you type `{image_token}`, the chatbot will ask you to input image file path. \n" + f" 4. type `help` to get the help messages. \n" + f" 5. type `new` to start a new conversation. \n" + f" Here is an example, you can type: 'Describe the image.'\n" + ) + + return help_msg + + +@torch.inference_mode() +def response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config): + + prompt = conv.get_prompt() + prepare_inputs = vl_chat_processor.__call__( + prompt=prompt, + images=pil_images, + force_batchify=True + ).to(vl_gpt.device) + + # run image encoder to get the image embeddings + inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) + + streamer = TextIteratorStreamer( + tokenizer=tokenizer, + skip_prompt=True, + skip_special_tokens=True + ) + generation_config["inputs_embeds"] = inputs_embeds + generation_config["attention_mask"] = prepare_inputs.attention_mask + generation_config["streamer"] = streamer + + thread = Thread(target=vl_gpt.language_model.generate, kwargs=generation_config) + thread.start() + + yield from streamer + + +def get_user_input(hint: str): + user_input = "" + while user_input == "": + try: + user_input = input(f"{hint}") + except KeyboardInterrupt: + print() + continue + except EOFError: + user_input = "exit" + + return user_input + + +def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config): + image_token = vl_chat_processor.image_token + help_msg = get_help_message(image_token) + + while True: + + print(help_msg) + + pil_images = [] + conv = vl_chat_processor.new_chat_template() + roles = conv.roles + + while True: + + # get user input + user_input = get_user_input(f"{roles[0]} [{image_token} indicates an image]: ") + + if user_input == "exit": + print("Chat program exited.") + sys.exit(0) + + elif user_input == "help": + print(help_msg) + + elif user_input == "new": + os.system("clear") + pil_images = [] + conv = vl_chat_processor.new_chat_template() + torch.cuda.empty_cache() + print("New conversation started.") + + else: + conv.append_message(conv.roles[0], user_input) + conv.append_message(conv.roles[1], None) + + # check if the user input is an image token + num_images = user_input.count(image_token) + cur_img_idx = 0 + + while cur_img_idx < num_images: + try: + image_file = input(f"({cur_img_idx + 1}/{num_images}) Input the image file path: ") + + except KeyboardInterrupt: + print() + continue + + except EOFError: + image_file = None + + if image_file and os.path.exists(image_file): + pil_image = load_image(image_file) + pil_images.append(pil_image) + cur_img_idx += 1 + + elif image_file == "exit": + print("Chat program exited.") + sys.exit(0) + + else: + print(f"File error, `{image_file}` does not exist. Please input the correct file path.") + + # get the answer by the model's prediction + answer = "" + answer_iter = response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config) + sys.stdout.write(f"{conv.roles[1]}: ") + for char in answer_iter: + answer += char + sys.stdout.write(char) + sys.stdout.flush() + + sys.stdout.write("\n") + sys.stdout.flush() + conv.messages[-1][-1] = answer + + +def main(args): + + # setup + tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model(args.model_path) + generation_config = dict( + pad_token_id=vl_chat_processor.tokenizer.eos_token_id, + bos_token_id=vl_chat_processor.tokenizer.bos_token_id, + eos_token_id=vl_chat_processor.tokenizer.eos_token_id, + max_new_tokens=args.max_gen_len, + use_cache=True, + ) + if args.temperature > 0: + generation_config.update({ + "do_sample": True, + "top_p": args.top_p, + "temperature": args.temperature, + "repetition_penalty": args.repetition_penalty, + }) + else: + generation_config.update({"do_sample": False}) + + chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default="deepseek-ai/deepseek-vl-7b-chat", + help="the huggingface model name or the local path of the downloaded huggingface model.") + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--top_p", type=float, default=0.95) + parser.add_argument("--repetition_penalty", type=float, default=1.1) + parser.add_argument("--max_gen_len", type=int, default=512) + args = parser.parse_args() + main(args) + +""" + +CUDA_VISIBLE_DEVICES=4 python cli.py --model_path "/home/liuwen/3fs_shared/ckpts/deepseek-vl-7b-chat" + +CUDA_VISIBLE_DEVICES=2 python cli.py --model_path "/home/liuwen/3fs_shared/ckpts/siglip_1B_HF" + +""" diff --git a/deepseek_vlm/__init__.py b/deepseek_vlm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepseek_vlm/models/__init__.py b/deepseek_vlm/models/__init__.py new file mode 100644 index 0000000..2ef8b03 --- /dev/null +++ b/deepseek_vlm/models/__init__.py @@ -0,0 +1,5 @@ + +from .image_processing_vlm import VLMImageProcessor +from .processing_vlm import VLChatProcessor +from .modeling_vlm import MultiModalityCausalLM + diff --git a/deepseek_vlm/models/clip_encoder.py b/deepseek_vlm/models/clip_encoder.py new file mode 100644 index 0000000..0c3c9d0 --- /dev/null +++ b/deepseek_vlm/models/clip_encoder.py @@ -0,0 +1,213 @@ +from typing import Tuple, Union, List, Dict, Optional, Literal + +import torch +import torch.nn as nn +import torchvision.transforms +from einops import rearrange + +from deepseek_vlm.models.siglip_vit import create_siglip_vit +from deepseek_vlm.models.sam import create_sam_vit + + +class CLIPVisionTower(nn.Module): + + def __init__(self, + model_name: str = "siglip_large_patch16_384", + image_size: Union[Tuple[int, int], int] = 336, + select_feature: str = "patch", + select_layer: int = -2, + select_layers: list = None, + ckpt_path: str = "", + pixel_mean: Optional[List[float]] = None, + pixel_std: Optional[List[float]] = None, + **kwargs): + super().__init__() + + self.model_name = model_name + self.select_feature = select_feature + self.select_layer = select_layer + self.select_layers = select_layers + + vision_tower_params = { + "model_name": model_name, + "image_size": image_size, + "ckpt_path": ckpt_path, + "select_layer": select_layer + } + vision_tower_params.update(kwargs) + self.vision_tower, self.forward_kwargs = self.build_vision_tower(vision_tower_params) + + if pixel_mean is not None and pixel_std is not None: + image_norm = torchvision.transforms.Normalize(mean=pixel_mean, std=pixel_std) + else: + image_norm = None + + self.image_norm = image_norm + + def build_vision_tower(self, vision_tower_params): + + if self.model_name.startswith("siglip"): + self.select_feature = "same" + vision_tower = create_siglip_vit(**vision_tower_params) + forward_kwargs = dict() + + elif self.model_name.startswith("sam"): + vision_tower = create_sam_vit(**vision_tower_params) + forward_kwargs = dict() + + else: # huggingface + from transformers import CLIPVisionModel + vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params) + forward_kwargs = dict(output_hidden_states=True) + + return vision_tower, forward_kwargs + + def feature_select(self, image_forward_outs): + + if isinstance(image_forward_outs, torch.Tensor): + # the output has been the self.select_layer"s features + image_features = image_forward_outs + else: + image_features = image_forward_outs.hidden_states[self.select_layer] + + if self.select_feature == "patch": + # if the output has cls_token + image_features = image_features[:, 1:] + elif self.select_feature == "cls_patch": + image_features = image_features + elif self.select_feature == "same": + image_features = image_features + + else: + raise ValueError(f"Unexpected select feature: {self.select_feature}") + return image_features + + def forward(self, images): + """ + + Args: + images (torch.Tensor): [b, 3, H, W] + + Returns: + image_features (torch.Tensor): [b, n_patch, d] + """ + + if self.image_norm is not None: + images = self.image_norm(images) + + image_forward_outs = self.vision_tower(images, **self.forward_kwargs) + image_features = self.feature_select(image_forward_outs) + return image_features + + +class HybridVisionTower(nn.Module): + + def __init__(self, + high_res_cfg: Dict, + low_res_cfg: Dict, + freeze_high: bool = False, + freeze_low: bool = False, + concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple", + **ignore_kwargs): + + super().__init__() + + self.vision_tower_high = CLIPVisionTower(**high_res_cfg) + self.vision_tower_low = CLIPVisionTower(**low_res_cfg) + self.low_res_size = low_res_cfg["image_size"] + self.concat_type = concat_type + + self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024)) + self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024)) + + if freeze_high: + for p_name, p in self.vision_tower_high.named_parameters(): + p.requires_grad = False + self.vision_tower_high = self.vision_tower_high.eval() + else: + # train donwsamples and neck + for p_name, p in self.vision_tower_high.named_parameters(): + if "downsamples" in p_name or "neck" in p_name: + p.requires_grad = True + else: + p.requires_grad = False + + if freeze_low: + for p in self.vision_tower_low.parameters(): + p.requires_grad = False + self.vision_tower_low = self.vision_tower_low.eval() + + self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True) + + def forward(self, images: torch.Tensor): + """ + + Args: + images (torch.Tensor): [bs, 3, H, W] + + Returns: + res (torch.Tensor): [bs, t, c] + """ + + # [bs, c, h, w] + high_images = images + + # [bs, c, h_low, w_low] + low_images = self.resize(images) + + # separately run two vision towers + # run high_res vision tower + high_res = self.vision_tower_high(high_images) + # [bs, c, h, w] -> [bs, h*w, c] + high_res = rearrange(high_res, "b c h w -> b (h w) c") + # run low_res vision tower + low_res = self.vision_tower_low(low_images) + + if self.concat_type == "feature": + images_features = torch.cat([high_res, low_res], dim=-1) + elif self.concat_type == "sequence": + images_features = torch.cat([high_res, low_res], dim=1) + elif self.concat_type == "add": + images_features = high_res + low_res + elif self.concat_type == "tuple": + images_features = (high_res, low_res) + + else: + raise ValueError(f"Currently only support `feature`, `sequence`, `add` and `tuple` concat type.") + + return images_features + + +if __name__ == "__main__": + image_size = 1024 + x = torch.zeros(2, 3, image_size, image_size).bfloat16().cuda() + + high_res_cfg = dict( + model_name="sam_b_downsample", + select_feature="same", + image_size=image_size, + pixel_mean=(0.48145466, 0.4578275, 0.40821073), + pixel_std=(0.26862954, 0.26130258, 0.27577711), + select_layer=-1, + ckpt_path="" + ) + + low_res_cfg = dict( + model_name="siglip_large_patch16_384", + select_feature="same", + image_size=384, + pixel_mean=(0.5, 0.5, 0.5), + pixel_std=(0.5, 0.5, 0.5), + select_layer=-1, + ckpt_path="" + ) + + net = HybridVisionTower( + high_res_cfg=high_res_cfg, + low_res_cfg=low_res_cfg, + freeze_high=True, + freeze_low=True, + concat_type="tuple" + ).bfloat16().cuda() + high_x, low_x = net(x) + print(x.shape, high_x.shape, low_x.shape) diff --git a/deepseek_vlm/models/image_processing_vlm.py b/deepseek_vlm/models/image_processing_vlm.py new file mode 100644 index 0000000..984117e --- /dev/null +++ b/deepseek_vlm/models/image_processing_vlm.py @@ -0,0 +1,163 @@ +from PIL import Image +import numpy as np +import torch +import torchvision +import torchvision.transforms.functional +from typing import List, Union, Tuple +from transformers import PretrainedConfig, AutoImageProcessor +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers.image_utils import to_numpy_array +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +ImageType = Union[np.ndarray, torch.Tensor, Image.Image] +IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073) +IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711) +IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) +IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +class VLMImageProcessorConfig(PretrainedConfig): + + model_type = "deepseek_vlm" + image_size: int + min_size: int + image_mean: Union[Tuple[float, float, float], List[float]] + image_std: Union[Tuple[float, float, float], List[float]] + rescale_factor: float + do_normalize: bool + + def __init__( + self, + image_size: int, + min_size: int = 14, + image_mean: Union[Tuple[float, float, float], List[float]] = (0.48145466, 0.4578275, 0.40821073), + image_std: Union[Tuple[float, float, float], List[float]] = (0.26862954, 0.26130258, 0.27577711), + rescale_factor: float = 1.0 / 255.0, + do_normalize: bool = True, **kwargs + ): + + self.image_size = image_size + self.min_size = min_size + self.image_mean = image_mean + self.image_std = image_std + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + + super().__init__(**kwargs) + + +class VLMImageProcessor(BaseImageProcessor): + model_input_names = ["pixel_values"] + + def __init__( + self, + image_size: int, + min_size: int = 14, + image_mean: Union[Tuple[float, float, float], List[float]] = (0.48145466, 0.4578275, 0.40821073), + image_std: Union[Tuple[float, float, float], List[float]] = (0.26862954, 0.26130258, 0.27577711), + rescale_factor: float = 1.0 / 255.0, + do_normalize: bool = True, **kwargs + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.rescale_factor = rescale_factor + self.image_mean = image_mean + self.image_std = image_std + self.min_size = min_size + self.do_normalize = do_normalize + + if image_mean is None: + self.background_color = (127, 127, 127) + else: + self.background_color = tuple([int(x * 255) for x in image_mean]) + + def resize(self, pil_img: Image) -> np.ndarray: + """ + + Args: + pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB + + Returns: + x (np.ndarray): [3, self.image_size, self.image_size] + """ + + width, height = pil_img.size + max_size = max(width, height) + + size = [ + max(int(height / max_size * self.image_size), self.min_size), + max(int(width / max_size * self.image_size), self.min_size) + ] + + if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0: + print(f"orig size = {pil_img.size}, new size = {size}") + raise ValueError("Invalid size!") + + pil_img = torchvision.transforms.functional.resize( + pil_img, size, interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC, antialias=True + ) + + pil_img = expand2square(pil_img, self.background_color) + x = to_numpy_array(pil_img) + + # [H, W, 3] -> [3, H, W] + x = np.transpose(x, (2, 0, 1)) + + return x + + def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature: + + # resize and pad to [self.image_size, self.image_size] + # then convert from [H, W, 3] to [3, H, W] + images: List[np.ndarray] = [self.resize(image) for image in images] + + # resacle from [0, 255] -> [0, 1] + images = [ + self.rescale(image=image, scale=self.rescale_factor, input_data_format="channels_first") + for image in images + ] + + # normalize + if self.do_normalize: + images = [ + self.normalize(image=image, mean=self.image_mean, std=self.image_std, + input_data_format="channels_first") + for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + @property + def default_shape(self): + return [3, self.image_size, self.image_size] + + +# AutoConfig.register("deepseek_vlm", VLMImageProcessorConfig) +AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor) + + +if __name__ == "__main__": + image_processor = VLMImageProcessor( + image_size=1024, + image_mean=IMAGENET_INCEPTION_MEAN, + image_std=IMAGENET_INCEPTION_STD, + do_normalize=True + ) diff --git a/deepseek_vlm/models/modeling_vlm.py b/deepseek_vlm/models/modeling_vlm.py new file mode 100644 index 0000000..df56d44 --- /dev/null +++ b/deepseek_vlm/models/modeling_vlm.py @@ -0,0 +1,150 @@ +from attrdict import AttrDict +from einops import rearrange +import torch +from transformers.configuration_utils import PretrainedConfig +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + PreTrainedModel, + LlamaConfig, + LlamaForCausalLM +) + +from deepseek_vlm.models.projector import MlpProjector +from deepseek_vlm.models.clip_encoder import CLIPVisionTower, HybridVisionTower + + +def model_name_to_cls(cls_name): + + if "MlpProjector" in cls_name: + cls = MlpProjector + + elif "CLIPVisionTower" in cls_name: + cls = CLIPVisionTower + + elif "HybridVisionTower" in cls_name: + cls = HybridVisionTower + + else: + raise ValueError(f"class_name {cls_name} is invalid.") + + return cls + + +class VisionConfig(PretrainedConfig): + model_type = "vision" + cls: str = "" + params: AttrDict = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.cls = kwargs.get("cls", "") + if not isinstance(self.cls, str): + self.cls = self.cls.__name__ + + self.params = AttrDict(kwargs.get("params", {})) + + +class AlignerConfig(PretrainedConfig): + model_type = "aligner" + cls: str = "" + params: AttrDict = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.cls = kwargs.get("cls", "") + if not isinstance(self.cls, str): + self.cls = self.cls.__name__ + + self.params = AttrDict(kwargs.get("params", {})) + + +class MultiModalityConfig(PretrainedConfig): + model_type = "multi_modality" + vision_config: VisionConfig + aligner_config: AlignerConfig + language_config: LlamaConfig + + def __init__(self, **kwargs): + super().__init__(**kwargs) + vision_config = kwargs.get("vision_config", {}) + self.vision_config = VisionConfig(**vision_config) + + aligner_config = kwargs.get("aligner_config", {}) + self.aligner_config = AlignerConfig(**aligner_config) + + language_config = kwargs.get("language_config", {}) + if isinstance(language_config, LlamaConfig): + self.language_config = language_config + else: + self.language_config = LlamaConfig(**language_config) + + +class MultiModalityPreTrainedModel(PreTrainedModel): + config_class = MultiModalityConfig + base_model_prefix = "multi_modality" + _no_split_modules = [] + _skip_keys_device_placement = "past_key_values" + + +class MultiModalityCausalLM(MultiModalityPreTrainedModel): + + def __init__(self, config: MultiModalityConfig): + super().__init__(config) + + vision_config = config.vision_config + vision_cls = model_name_to_cls(vision_config.cls) + self.vision_model = vision_cls(**vision_config.params) + + aligner_config = config.aligner_config + aligner_cls = model_name_to_cls(aligner_config.cls) + self.aligner = aligner_cls(aligner_config.params) + + language_config = config.language_config + self.language_model = LlamaForCausalLM(language_config) + + def prepare_inputs_embeds(self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + images_seq_mask: torch.LongTensor, + images_emb_mask: torch.LongTensor, **kwargs): + """ + + Args: + input_ids (torch.LongTensor): [b, T] + pixel_values (torch.FloatTensor): [b, n_images, 3, h, w] + images_seq_mask (torch.BoolTensor): [b, T] + images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens] + + assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask) + + Returns: + input_embeds (torch.Tensor): [b, T, D] + """ + + bs, n = pixel_values.shape[0:2] + images = rearrange(pixel_values, "b n c h w -> (b n) c h w") + # [b x n, T2, D] + images_embeds = self.aligner(self.vision_model(images)) + + # [b x n, T2, D] -> [b, n x T2, D] + images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) + # [b, n, T2] -> [b, n x T2] + images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") + + # [b, T, D] + input_ids[input_ids < 0] = 0 # ignore the image embeddings + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + # replace with the image embeddings + inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask] + + return inputs_embeds + + +AutoConfig.register("vision", VisionConfig) +AutoConfig.register("aligner", AlignerConfig) +AutoConfig.register("multi_modality", MultiModalityConfig) +AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM) diff --git a/deepseek_vlm/models/processing_vlm.py b/deepseek_vlm/models/processing_vlm.py new file mode 100644 index 0000000..913ec93 --- /dev/null +++ b/deepseek_vlm/models/processing_vlm.py @@ -0,0 +1,351 @@ +from dataclasses import dataclass + +import numpy as np +from PIL.Image import Image +from typing import List, Dict, Union +import torch + +from transformers import AutoTokenizer, AutoImageProcessor +from transformers.processing_utils import ProcessorMixin +from transformers import LlamaTokenizerFast + +from deepseek_vlm.models.image_processing_vlm import VLMImageProcessor +from deepseek_vlm.utils.conversation import get_conv_template + + +class DictOutput(object): + + def keys(self): + return self.__dict__.keys() + + def __getitem__(self, item): + return self.__dict__[item] + + def __setitem__(self, key, value): + self.__dict__[key] = value + + +@dataclass +class VLChatProcessorOutput(DictOutput): + sft_format: str + input_ids: torch.Tensor + pixel_values: torch.Tensor + num_image_tokens: torch.IntTensor + + def __len__(self): + return len(self.input_ids) + + +@dataclass +class BatchedVLChatProcessorOutput(DictOutput): + sft_format: List[str] + input_ids: torch.Tensor + pixel_values: torch.Tensor + attention_mask: torch.Tensor + images_seq_mask: torch.BoolTensor + images_emb_mask: torch.BoolTensor + + def to(self, device, dtype=torch.bfloat16): + self.input_ids = self.input_ids.to(device) + self.attention_mask = self.attention_mask.to(device) + self.images_seq_mask = self.images_seq_mask.to(device) + self.images_emb_mask = self.images_emb_mask.to(device) + self.pixel_values = self.pixel_values.to(device=device, dtype=dtype) + return self + + +class VLChatProcessor(ProcessorMixin): + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + + attributes = ["image_processor", "tokenizer"] + + system_prompt = ("You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.") + + def __init__( + self, + image_processor: VLMImageProcessor, + tokenizer: LlamaTokenizerFast, + image_tag: str = "", + num_image_tokens: int = 576, + add_special_token: bool = False, + sft_format: str = "deepseek", + mask_prompt: bool = True, + ignore_id: int = -100, **kwargs + ): + + self.image_processor = image_processor + self.tokenizer = tokenizer + + image_id = self.tokenizer.vocab.get(image_tag) + if image_id is None: + special_tokens = [image_tag] + special_tokens_dict = {"additional_special_tokens": special_tokens} + self.tokenizer.add_special_tokens(special_tokens_dict) + print(f"Add image tag = {image_tag} to the tokenizer") + + self.image_tag = image_tag + self.num_image_tokens = num_image_tokens + self.add_special_token = add_special_token + self.sft_format = sft_format + self.mask_prompt = mask_prompt + self.ignore_id = ignore_id + + super().__init__(image_processor, tokenizer, image_tag, num_image_tokens, add_special_token, + sft_format, mask_prompt, ignore_id, **kwargs) + + def new_chat_template(self): + conv = get_conv_template(self.sft_format) + conv.set_system_message(self.system_prompt) + return conv + + def apply_sft_template_for_multi_turn_prompts( + self, + conversations: List[Dict[str, str]], + sft_format: str = "deepseek", + system_prompt: str = "" + ): + """ + Applies the SFT template to conversation. + + An example of conversation: + conversation = [ + { + "role": "User", + "content": " is Figure 1.\n is Figure 2.\nWhich image is brighter?", + "images": [ + "./multi-images/attribute_comparison_1.png", + "./multi-images/attribute_comparison_2.png" + ] + }, + { + "role": "Assistant", + "content": "" + } + ] + + Args: + conversations (List[Dict]): A conversation with a List of Dict[str, str] text. + sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek". + system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "". + + Returns: + sft_prompt (str): The formatted text. + """ + + conv = get_conv_template(sft_format) + conv.set_system_message(system_prompt) + for message in conversations: + conv.append_message(message["role"], message["content"].strip()) + sft_prompt = conv.get_prompt().strip() + + return sft_prompt + + @property + def image_token(self): + return self.image_tag + + @property + def image_id(self): + image_id = self.tokenizer.vocab.get(self.image_tag) + return image_id + + @property + def pad_id(self): + pad_id = self.tokenizer.pad_token_id + if pad_id is None: + pad_id = self.tokenizer.eos_token_id + + return pad_id + + def add_image_token( + self, + image_indices: List[int], + input_ids: torch.LongTensor, + ): + """ + + Args: + image_indices (List[int]): [index_0, index_1, ..., index_j] + input_ids (torch.LongTensor): [N] + + Returns: + input_ids (torch.LongTensor): [N + image tokens] + num_image_tokens (torch.IntTensor): [n_images] + """ + + input_slices = [] + + start = 0 + for index in image_indices: + if self.add_special_token: + end = index + 1 + else: + end = index + + # original text tokens + input_slices.append(input_ids[start: end]) + + # add image tokens, and set the mask as False + input_slices.append(self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)) + start = index + 1 + + # the left part + input_slices.append(input_ids[start:]) + + # concat all slices + input_ids = torch.cat(input_slices, dim=0) + num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices)) + + return input_ids, num_image_tokens + + def process_one( + self, + prompt: str = None, + conversations: List[Dict[str, str]] = None, + images: List[Image] = None, + **kwargs + ): + """ + + Args: + prompt (str): the formatted prompt; + conversations (List[Dict]): conversations with a list of messages; + images (List[ImageType]): the list of images; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - target_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + assert prompt is None or conversations is None, "prompt and conversations cannot be used at the same time." + + if prompt is None: + # apply sft format + sft_format = self.apply_sft_template_for_multi_turn_prompts( + conversations=conversations, + sft_format=self.sft_format, + system_prompt=self.system_prompt + ) + else: + sft_format = prompt + + # tokenize + input_ids = self.tokenizer.encode(sft_format) + input_ids = torch.LongTensor(input_ids) + + # add image tokens to the input_ids + image_token_mask: torch.BoolTensor = input_ids == self.image_id + image_indices = image_token_mask.nonzero() + input_ids, num_image_tokens = self.add_image_token( + image_indices=image_indices, + input_ids=input_ids, + ) + + # load images + images_outputs = self.image_processor(images, return_tensors="pt") + + prepare = VLChatProcessorOutput( + sft_format=sft_format, + input_ids=input_ids, + pixel_values=images_outputs.pixel_values, + num_image_tokens=num_image_tokens + ) + + return prepare + + def __call__( + self, + *, + prompt: str = None, + conversations: List[Dict[str, str]] = None, + images: List[Image] = None, + force_batchify: bool = True, + **kwargs + ): + """ + + Args: + prompt (str): the formatted prompt; + conversations (List[Dict]): conversations with a list of messages; + images (List[ImageType]): the list of images; + force_batchify (bool): force batchify the inputs; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + prepare = self.process_one(prompt=prompt, conversations=conversations, images=images) + + if force_batchify: + prepare = self.batchify([prepare]) + + return prepare + + def batchify(self, prepare_list: List[VLChatProcessorOutput]) -> BatchedVLChatProcessorOutput: + """ + Preprocesses the inputs for multimodal inference. + + Args: + prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput. + + Returns: + BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference. + """ + + batch_size = len(prepare_list) + sft_format = [] + n_images = [] + seq_lens = [] + for prepare in prepare_list: + n_images.append(len(prepare.num_image_tokens)) + seq_lens.append(len(prepare)) + + input_token_max_len = max(seq_lens) + max_n_images = max(1, max(n_images)) + + batched_input_ids = torch.full((batch_size, input_token_max_len), self.pad_id).long() # FIXME + batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long() + batched_pixel_values = torch.zeros((batch_size, max_n_images, *self.image_processor.default_shape)).float() + batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool() + batched_images_emb_mask = torch.zeros((batch_size, max_n_images, self.num_image_tokens)).bool() + + for i, prepare in enumerate(prepare_list): + input_ids = prepare.input_ids + seq_len = len(prepare) + n_image = len(prepare.num_image_tokens) + # left-padding + batched_attention_mask[i, -seq_len:] = 1 + batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids) + batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id + + if n_image > 0: + batched_pixel_values[i, :n_image] = prepare.pixel_values + for j, n_image_tokens in enumerate(prepare.num_image_tokens): + batched_images_emb_mask[i, j, :n_image_tokens] = True + + sft_format.append(prepare.sft_format) + + batched_prepares = BatchedVLChatProcessorOutput( + input_ids=batched_input_ids, + attention_mask=batched_attention_mask, + pixel_values=batched_pixel_values, + images_seq_mask=batched_images_seq_mask, + images_emb_mask=batched_images_emb_mask, + sft_format=sft_format + ) + + return batched_prepares diff --git a/deepseek_vlm/models/projector.py b/deepseek_vlm/models/projector.py new file mode 100644 index 0000000..0f9b419 --- /dev/null +++ b/deepseek_vlm/models/projector.py @@ -0,0 +1,80 @@ +from attrdict import AttrDict +import torch +import torch.nn as nn +from typing import Union, Tuple + + +class MlpProjector(nn.Module): + + def __init__(self, cfg): + + super().__init__() + + self.cfg = cfg + + if cfg.projector_type == "identity": + modules = nn.Identity() + + elif cfg.projector_type == "linear": + modules = nn.Linear(cfg.input_dim, cfg.n_embed) + + elif cfg.projector_type == "mlp_gelu": + mlp_depth = cfg.get("depth", 1) + modules = [nn.Linear(cfg.input_dim, cfg.n_embed)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) + modules = nn.Sequential(*modules) + + elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu": + mlp_depth = cfg.get("depth", 1) + self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) + self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) + + modules = [] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) + modules = nn.Sequential(*modules) + + else: + raise ValueError(f"Unknown projector type: {cfg.projector_type}") + + self.layers = modules + + def forward(self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): + """ + + Args: + x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor, + then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x); + otherwise it is the feature from the single vision encoder. + + Returns: + x (torch.Tensor): [b, s, c] + """ + + if isinstance(x_or_tuple, tuple): + # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu": + high_x, low_x = x_or_tuple + high_x = self.high_up_proj(high_x) + low_x = self.low_up_proj(low_x) + x = torch.concat([high_x, low_x], dim=-1) + else: + x = x_or_tuple + + return self.layers(x) + + +if __name__ == "__main__": + cfg = AttrDict( + input_dim=1024, + n_embed=2048, + depth=2, + projector_type="low_high_hybrid_split_mlp_gelu" + ) + inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024)) + + m = MlpProjector(cfg) + out = m(inputs) + print(out.shape) diff --git a/deepseek_vlm/models/sam.py b/deepseek_vlm/models/sam.py new file mode 100644 index 0000000..040006c --- /dev/null +++ b/deepseek_vlm/models/sam.py @@ -0,0 +1,562 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import copy +from dataclasses import dataclass +from functools import partial +from typing import List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + downsample_channels: Tuple[int, ...] = (512, 1024), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + downsample_channels (list): Channels for downsampling layers. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + in_channels = out_chans + downsamples = [] + for i in range(len(downsample_channels)): + out_channels = downsample_channels[i] + downsamples.append( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ) + ) + in_channels = out_channels + self.downsamples = nn.Sequential(*downsamples) + + self.sam_hd = True + if self.sam_hd: + self.hd_alpha_downsamples = nn.Parameter(torch.zeros(1)) + # self.neck_hd = nn.Linear(embed_dim, embed_dim) + self.neck_hd = copy.deepcopy(self.neck) + # self.downsamples_hd = copy.deepcopy(self.downsamples) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + global_features = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if self.sam_hd and blk.window_size == 0: + global_features.append(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + x_dtype = x.dtype + x = F.interpolate(x.float(), size=(96, 96), mode='bilinear', align_corners=False).to(x_dtype) + x = self.downsamples(x) + + if self.sam_hd: + first_global_feature = self.neck_hd(global_features[0].permute(0, 3, 1, 2)) + x_dtype = first_global_feature.dtype + first_global_feature = F.interpolate(first_global_feature.float(), size=(96, 96), mode='bilinear', align_corners=False) + first_global_feature = self.downsamples(first_global_feature.to(x_dtype)) + x = x + first_global_feature * self.hd_alpha_downsamples + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + def do_attention(q, k, v): + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + + return x + + # from haiscale.utils import on_demand_checkpoint + # x = on_demand_checkpoint(do_attention, q, k, v) + x = do_attention(q, k, v) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +@dataclass +class SAMViTCfg: + image_size: Union[Tuple[int, int], int] = 1024 + width: int = 1024 + layers: int = 23 + heads: int = 16 + patch_size: int = 16 + window_size: int = 14 + prompt_embed_dim: int = 256 + global_attn_indexes: Union[List[int], Tuple[int]] = (5, 11, 17, 23) + downsample_channels: Union[List[int], Tuple[int]] = (512, 1024) + + +SAM_MODEL_CONFIG = { + "sam_vit_b": { + "width": 768, + "layers": 12, + "heads": 12, + "global_attn_indexes": [2, 5, 8, 11], + "downsample_channels": () + }, + + "sam_b_downsample": { + "width": 768, + "layers": 12, + "heads": 12, + "global_attn_indexes": [2, 5, 8, 11], + "downsample_channels": (512, 1024) + }, + + "sam_vit_l": { + "width": 1024, + "layers": 24, + "heads": 16, + "global_attn_indexes": [5, 11, 17, 23], + "downsample_channels": () + }, + "sam_vit_h": { + "width": 1280, + "layers": 32, + "heads": 16, + "global_attn_indexes": [7, 15, 23, 31], + "downsample_channels": () + }, +} + + +def create_sam_vit( + model_name: str = "sam_b_downsample", + image_size: int = 1024, + ckpt_path: str = "", + **kwargs +): + assert model_name in SAM_MODEL_CONFIG.keys(), f"model name: {model_name} should be in {SAM_MODEL_CONFIG.keys()}" + + sam_cfg = SAMViTCfg(**SAM_MODEL_CONFIG[model_name]) + image_encoder = ImageEncoderViT( + depth=sam_cfg.layers, + embed_dim=sam_cfg.width, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=sam_cfg.heads, + patch_size=sam_cfg.patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=sam_cfg.global_attn_indexes, + window_size=14, + out_chans=sam_cfg.prompt_embed_dim, + downsample_channels=sam_cfg.downsample_channels + ) + + if ckpt_path: + state_dict = torch.load(ckpt_path) + image_encoder.load_state_dict(state_dict, strict=False) + print(f"SAM-ViT restores from {ckpt_path}") + + return image_encoder + + +if __name__ == '__main__': + x = torch.zeros(2, 3, 1024, 1024).bfloat16() + # x.permute(0, 3, 1, 2) + net = create_sam_vit().bfloat16() + out = net(x) + print(x.shape, out.shape) diff --git a/deepseek_vlm/models/siglip_vit.py b/deepseek_vlm/models/siglip_vit.py new file mode 100644 index 0000000..0798c37 --- /dev/null +++ b/deepseek_vlm/models/siglip_vit.py @@ -0,0 +1,605 @@ +# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py +from dataclasses import dataclass +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Final, Optional, Callable, Union, Tuple, List, Set, Dict, Type, Literal, Sequence +import math +from functools import partial +import warnings +from timm.layers import ( + PatchEmbed, Mlp, DropPath, + AttentionPoolLatent, PatchDropout, resample_abs_pos_embed, LayerType +) +from timm.models._manipulate import named_apply, checkpoint_seq + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r""" The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first + convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype. + Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn + from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + + with torch.no_grad(): + dtype = tensor.dtype + tensor_fp32 = tensor.float() + tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b) + tensor_dtype = tensor_fp32.to(dtype=dtype) + tensor.copy_(tensor_dtype) + + +def init_weights(self): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) + trunc_normal_(self.latent, std=self.latent_dim ** -0.5) + + +def init_weights_vit_timm(module: nn.Module, name: str = '') -> None: + """ ViT weight initialization, original timm impl (for reproducibility) """ + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + # self.fused_attn = use_fused_attn() + self.fused_attn = True + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + init_values: Optional[float] = None, + drop_path: float = 0., + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal['', 'avg', 'token', 'map'] = 'token', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0., + pos_drop_rate: float = 0., + patch_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '', + embed_layer: Callable = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ignore_head: bool = False + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Mumber of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + embed_layer: Patch embedding layer. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. + """ + super().__init__() + assert global_pool in ('', 'avg', 'token', 'map') + assert class_token or global_pool != 'token' + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + # act_layer = get_act_layer(act_layer) or nn.GELU + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + self.ignore_head = ignore_head + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + if global_pool == 'map': + AttentionPoolLatent.init_weights = init_weights + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + ) + else: + self.attn_pool = None + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + + def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None: + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + trunc_normal_(self.pos_embed, std=.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {'pos_embed', 'cls_token', 'dist_token'} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool = None) -> None: + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token', 'map') + if global_pool == 'map' and self.attn_pool is None: + assert False, "Cannot currently add attention pooling in reset_classifier()." + elif global_pool != 'map ' and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.dynamic_img_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + x = x + pos_embed + + return self.pos_drop(x) + + def _intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + ) -> List[torch.Tensor]: + outputs, num_blocks = [], len(self.blocks) + take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n) + + # forward pass + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in take_indices: + outputs.append(x) + + return outputs + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + """ Intermediate layer accessor (NOTE: This is a WIP experiment). + Inspired by DINO / DINOv2 interface + """ + # take last n blocks if n is an int, if in is a sequence, select by matching indices + outputs = self._intermediate_layers(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs] + outputs = [out[:, self.num_prefix_tokens:] for out in outputs] + + if reshape: + grid_size = self.patch_embed.grid_size + outputs = [ + out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + + if return_prefix_tokens: + return tuple(zip(outputs, prefix_tokens)) + return tuple(outputs) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + if self.attn_pool is not None: + x = self.attn_pool(x) + elif self.global_pool == 'avg': + x = x[:, self.num_prefix_tokens:].mean(dim=1) + elif self.global_pool: + x = x[:, 0] # class token + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + if not self.ignore_head: + x = self.forward_head(x) + return x + + +@dataclass +class SigLIPVisionCfg: + width: int = 1152 + layers: Union[Tuple[int, int, int, int], int] = 27 + heads: int = 16 + patch_size: int = 14 + image_size: Union[Tuple[int, int], int] = 336 + global_pool: str = "map" + mlp_ratio: float = 3.7362 + class_token: bool = False + num_classes: int = 0 + use_checkpoint: bool = False + + +SigLIP_MODEL_CONFIG = { + "siglip_so400m_patch14_384": { + "image_size": 336, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False + }, + + "siglip_so400m_patch14_224": { + "image_size": 224, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False + }, + + "siglip_large_patch16_384": { + "image_size": 384, + "patch_size": 16, + "width": 1024, + "layers": 24, + "heads": 16, + "mlp_ratio": 4, + "global_pool": "map", + "use_checkpoint": False + } +} + + +def create_siglip_vit( + model_name: str = "siglip_so400m_patch14_384", + image_size: int = 384, + select_layer: int = -1, + ckpt_path: str = "", + **kwargs +): + + assert model_name in SigLIP_MODEL_CONFIG.keys(), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}" + + vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name]) + + if select_layer <= 0: + layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1) + else: + layers = min(vision_cfg.layers, select_layer) + + model = VisionTransformer( + img_size=image_size, + patch_size=vision_cfg.patch_size, + embed_dim=vision_cfg.width, + depth=layers, + num_heads=vision_cfg.heads, + mlp_ratio=vision_cfg.mlp_ratio, + class_token=vision_cfg.class_token, + global_pool=vision_cfg.global_pool, + ignore_head=kwargs.get("ignore_head", True), + weight_init=kwargs.get("weight_init", "skip"), + num_classes=0 + ) + + if ckpt_path: + state_dict = torch.load(ckpt_path, map_location="cpu") + + incompatible_keys = model.load_state_dict(state_dict, strict=False) + print(f"SigLIP-ViT restores from {ckpt_path},\n" + f"\tincompatible_keys:', {incompatible_keys}.") + + return model diff --git a/deepseek_vlm/utils/__init__.py b/deepseek_vlm/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepseek_vlm/utils/conversation.py b/deepseek_vlm/utils/conversation.py new file mode 100644 index 0000000..d74fc4e --- /dev/null +++ b/deepseek_vlm/utils/conversation.py @@ -0,0 +1,326 @@ +""" +From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py +""" + +import dataclasses +from enum import IntEnum, auto +from typing import Dict, List + + +class SeparatorStyle(IntEnum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_COLON_SPACE_SINGLE = auto() + NO_COLON_SINGLE = auto() + NO_COLON_TWO = auto() + ADD_NEW_LINE_SINGLE = auto() + LLAMA2 = auto() + CHATGLM = auto() + CHATML = auto() + CHATINTERN = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + ROBIN = auto() + DeepSeek = auto() + PLAIN = auto() + ALIGNMENT = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + # The name of this template + name: str + # The template of the system prompt + system_template: str = "{system_message}" + # The system message + system_message: str = "" + # The names of two roles + roles: List[str] = (("USER", "ASSISTANT"),) + # All messages. Each item is (role, message). + messages: List[List[str]] = () + # The number of few shot examples + offset: int = 0 + # The separator style and configurations + sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE + sep: str = "\n" + sep2: str = None + # Stop criteria (the default one is EOS token) + stop_str: str = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + system_prompt = self.system_template.format(system_message=self.system_message) + + if self.sep_style == SeparatorStyle.DeepSeek: + seps = [self.sep, self.sep2] + if system_prompt == "" or system_prompt is None: + ret = "" + else: + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.LLAMA2: + seps = [self.sep, self.sep2] + if self.system_message: + ret = system_prompt + else: + ret = "[INST] " + for i, (role, message) in enumerate(self.messages): + tag = self.roles[i % 2] + if message: + if type(message) is tuple: # multimodal message + message, _ = message + if i == 0: + ret += message + " " + else: + ret += tag + " " + message + seps[i % 2] + else: + ret += tag + return ret + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = "" + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, _, _ = message + if i % 2 == 0: + ret += message + seps[i % 2] + else: + ret += message + seps[i % 2] + else: + ret += "" + return ret + elif self.sep_style == SeparatorStyle.ALIGNMENT: + seps = [self.sep, self.sep2] + ret = "" + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, _, _ = message + if i % 2 == 0: + ret += '\n' + seps[i % 2] + else: + ret += message + seps[i % 2] + else: + ret += "" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def get_prompt_for_current_round(self, content=None): + """Get current round formatted question prompt during sft training""" + if self.sep_style == SeparatorStyle.PLAIN: + formatted_question = "\n" + elif self.sep_style == SeparatorStyle.DeepSeek: + formatted_question = f"{self.roles[0]}: " +content.strip() + self.sep + f"{self.roles[1]}:" + else: + raise ValueError(f"Unsupported sep_style: {self.sep_style}") + return formatted_question + + def set_system_message(self, system_message: str): + """Set the system message.""" + self.system_message = system_message + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def reset_message(self): + """Reset a new message.""" + self.messages = [] + + def update_last_message(self, message: str): + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. + """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + system_prompt = self.system_template.format(system_message=self.system_message) + ret = [{"role": "system", "content": system_prompt}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({"role": "user", "content": msg}) + else: + if msg is not None: + ret.append({"role": "assistant", "content": msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + ) + + def dict(self): + return { + "template_name": self.name, + "system_message": self.system_message, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert template.name not in conv_templates, f"{template.name} has been registered." + + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + return conv_templates[name].copy() + + +# llava_llama2 template +register_conv_template( + Conversation( + name="llava_llama2", + system_message="You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + system_template="[INST] <>\n{system_message}\n<>\n\n", + roles=("[INST]", "[/INST]"), + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + stop_token_ids=[2], + ) +) + +# llama2 template +# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212 +register_conv_template( + Conversation( + name="llama-2", + system_template="[INST] <>\n{system_message}\n<>\n\n", + roles=("[INST]", "[/INST]"), + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + stop_token_ids=[2], + ) +) + + +# deepseek template +register_conv_template( + Conversation( + name="deepseek", + system_template="{system_message}", + # system_message="You are a helpful assistant. Please answer truthfully and write out your " + # "thinking step by step to be sure you get the right answer.", + system_message="", + roles=("User", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.DeepSeek, + sep="\n\n", + sep2="<|end▁of▁sentence|>", + stop_token_ids=[100001], + stop_str=["User:", "<|end▁of▁sentence|>"] + ) +) + +register_conv_template( + Conversation( + name="plain", + system_template="", + system_message="", + roles=("", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="", + sep2="", + stop_token_ids=[2], + stop_str=[''], + ) +) + + +register_conv_template( + Conversation( + name="alignment", + system_template="", + system_message="", + roles=("", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.ALIGNMENT, + sep="", + sep2="", + stop_token_ids=[2], + stop_str=[''], + ) +) + + +if __name__ == "__main__": + + # print("Llama-2 template:") + # conv = get_conv_template("llama-2") + # conv.set_system_message("You are a helpful, respectful and honest assistant.") + # conv.append_message(conv.roles[0], "Hello!") + # conv.append_message(conv.roles[1], "Hi!") + # conv.append_message(conv.roles[0], "How are you?") + # conv.append_message(conv.roles[1], None) + # print(conv.get_prompt()) + + # print("\n") + + print("deepseek template:") + conv = get_conv_template("deepseek") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi! This is Tony.") + conv.append_message(conv.roles[0], "Who are you?") + conv.append_message(conv.roles[1], "I am a helpful assistant.") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) diff --git a/deepseek_vlm/utils/io.py b/deepseek_vlm/utils/io.py new file mode 100644 index 0000000..183c481 --- /dev/null +++ b/deepseek_vlm/utils/io.py @@ -0,0 +1,55 @@ +import json +import PIL.Image +from typing import Dict, List +import torch +from transformers import AutoModelForCausalLM +from deepseek_vlm.models import VLChatProcessor, MultiModalityCausalLM + + +def load_pretrained_model(model_path: str): + vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) + tokenizer = vl_chat_processor.tokenizer + + vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) + vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval() + + return tokenizer, vl_chat_processor, vl_gpt + + +def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]: + """ + + Args: + conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : + [ + { + "role": "User", + "content": "\nExtract all information from this image and convert them into markdown format.", + "images": ["./examples/table_datasets.png"] + }, + {"role": "Assistant", "content": ""}, + ] + + Returns: + pil_images (List[PIL.Image.Image]): the list of PIL images. + + """ + + pil_images = [] + + for message in conversations: + if "images" not in message: + continue + + for image_path in message["images"]: + pil_img = PIL.Image.open(image_path) + pil_img = pil_img.convert("RGB") + pil_images.append(pil_img) + + return pil_images + + +def load_json(filepath): + with open(filepath, "r") as f: + data = json.load(f) + return data diff --git a/images/latex_01.jpg b/images/latex_01.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7a0fde444f65d92f56b70e0d8edbdce3d508d8ca GIT binary patch literal 50179 zcmeFZcT`hPw=f(Vf;2$^X(A;cT|$-mlb+BdH0jc%geo9irARL!5CZ}Np?63^2kBiD zLQRwo(xrEN`Q7Ka>v_Jl?!E6H?^@ql-}>I2b z4_x+d-1oKn>#H))|HeJ_p-NY{?G^sm;oorUf5UCuJ^wBrdsRjj=HmT#TYtmf98=i2 z8t7f6q*oUUzzYBcr~wrJ)ql4f06^*h0JxF(FPb$90H_TI03J;Li^laC0H6*70C2gND( z0oVZ^0l)xPfDJ(C3IhS207L*1e+U4jtKs}VgTE{JKda&D=|5ipG^E!QZ+YIh#sRoa zbL|GrwLhHzwkvU8zkdD3->UdexI<1#MsoYc&0E*5Ulp*u3jmN@yK&?CjT^UaQ{ExJ zMFzNb{VJ2>HYp7mE&Dx@JNKV)=vvV+ae_SP85E%}-X(E~N*VefQc|<4M)6}07$u(? zcvfSuZ9p*zJ#Sw)vSxfv$=WOVx48c6H}7#Hi#u1X-2b=s|H;PRRGNQr^`F#f0L6_f zmT7L#02BaOU;at{zsdiV!2b^<(7Cz0u2ufyFt1E{Shb@X=FS2>tJ8ck2U<^%L#JIp3}MWGpmPy|+#COI2}ll|%ex?u|A=X+~tmVzcd2Ff|ZE^Iq_8 zdrQo%S5sMC0gX(9`MWKgj50LyFJwEkxaZ)IyYW)s5Q7Va+$|QHy&6gHbF89Za^i6$8ggY77fUBo?zp_32QuoiiYsL{}tO$kZ{aqe<% zT{q3jKs8p(FR{%AdK@tJAxj@vbU~ zs<0wQCygS-1kB~I-{rTv{O21(KSCHD{{c{lUgFsPjhU{uS=-NDm`_Ak6^`@nT|5eD zepczWIE?-Sm`CXBzvVysf8ogkyVt3cVu-_3!t1gu7SC7+o5DHt8o=`F;E#zN&xGF{ z3hP!UQ|S$tLb{g^94=gR?yToNYHoU>T>+x~N1xc74`8Mdxgnb~}%teM1;fQ44SHcK^2>>Wn?eoui$K|E!D>`k?Q9gRKf=lG|B= zF1CJJ*azWCi^sAs^la42r%?&#ll5YM+J%>*Wt~b+aVUc*#)~e3K$k=fN5N{O3bhq=bY#$y4Q!4&8pkbr4I=|3nn8*d@FfP0M(zAsdozs zjXY2N0~nP0DW~s1)f}7~qk)De+>cDuJmiTdQ&s%icKQDT$dPNP-^cC=x-4>VkwoIJb7&E+9di`i=sLLH$X^(FfsRxIRS#E1I8<)B z7*|vdh(_firdfiz&(dJ{ybcadA@#ZfUw>vTyj<3Luw|a#re?!RNtug9sosu7{0Bak zDnxV_O_P(P%zR!h59zU`T-R*ZFK)A>q%*cq8t8FcgTm-G1Go%3qx}lyX1F5-sj9_A zQ7ldl^;G>5Y#`E=j6$K`^9pDo;tZ@keHVWj_|fYZf|ydMlyRKDD2Z7^>6$f_4P1=0CLYnu_R30tuD zs>`831~`jERn@vNFrBfnQ)tCzisE*O;ChB0Xs(xo$Jj)3Ae@E~DNnujboqlby~F2I zRP37lEjR9D_Rm?2I5{6a%Po&$)&1R)&HTDMg>OVm>v_*h?i>P>zg2&YnQSog z6FQf4Kt&MT<{YeZdLfgyQaSl`2PU5Xo$XQ3Jou}8I!)Ut!?n4Z7@aa4YtoW za=k-crJt6+Xzu1IloFDJUEktQ0`{5HNpHXj5h=Vb+WzDf)!V!r2ejo|6a$$Qoea{$ zfg;{gWrCND8>cN0jvl2^PZn$aA_jW+O0QmmKQ?hvkIac*Qw?euR*lYsa4v?w&+Xul zfiASt?ca+9+f8mix=0=E)pLFH2XKpBRqgCAedIR!@uP^#z z?0-f_EEyBP6`s1v@LU+CuC_4>U;CrycT3K6{{U%frHxKFdyy`>nh@92-`72^p=i*r zZkQ{S2W}K|R``$R2u!f1UcTzZ!$qOJG2v_vzeq~`r%?Yr zy~$Wa5v9C*g2WQH4pBUh27(wD^I9mU$*S1aFVI zhcZH9niQW`q&5qObBc-{vVV7%l+2bu4K%o!79G26d`FQ+Z^cWq$ad5)55NnNmKV3fiZ`T>Ba0%!o_@*R7}Tn?;D~PNu@77I>jd9tVsJ z+)kUI*}-1x4;@{jHHJ0m^@PhGwo?;RKR;tFL;h+!4yyTVtc2)CEONM;F=D0920)8_g18i3%Xg%>?W1}3 z@2r(hGN&pz$ANMI11WOTn3@Fh9D!6@AhjTtiAz1wtaM6@CTMNt*LD4u?55%8hRBcd zI`KAy)amA*2*n|2%pLwuNcNajzLIfQ=_OFAOP078l0Fbskg2!11@yzp;T^TN_cKe3 zUWWop`>KkTBZvD-{B%TUhbXeW9`QPJi288m;|%Tkm=XgS-}b%yrz)^TL!l|Ws$hsZ zK9BLYGF(V=@G?x<43+SmEx0zwy7FsggN`I~*CyxixMmL4)#iTX%6VD!NM7r$gC3yN z2;HhQPRp!>*juOY*;Sh0)j$nK)fC{Qb0bs-E%$>2{0Nh<>5^r7Eetp!@3i1&4OhYY z-E#6n-+uj+JbrLe9vH|qFgcR57~6V-bMdc)Y^q-Gqriq$2Rya!z0Gp#zRLQT*a_Pt zvMwcM=BLT8;-b_wqI941=iF0S3t;CtxHn~osUFoVXg7gTqFWZTx!`{Q?02dEF**PM zRVy9_{enqNVtkmrbbm5~W09NBrGiJ)l6qXXAmYMzFo_t{)#i1;`u2%)q5LPevi^k*pGH*V{%KIp}q3FHnq zqCXSuyQAA;NaG*HM&gklp@})1x(r;w%a(6e-s4~yZZlqTG?Fow3m4rcGseI4rdTvR zas6Y<>%1bp5HnP@{i%GewqT7bEfX||j*Xmi_A-_0*K(QF=}f_!mb{mTRYD@f@?tsc z>t1Kc2R>q)PNK;+bj#|FMg z-A9^mFa5AC4c!?NGvk(myIOGdvmuq9x~(}T`PSwqyIS~Jl{!II_>y7OY6JiFg7HMh zQ;Cq|-{TLe>=68;V%aQT=D3}OM&$$DuY(T`pTKfN<#DF>ta`zMC=sGYXPbi2tRxkE&M zWvFxNl2=s>uXhF5TG?s!$sEv)b^)~Wd!`&4oCr?{*5}(F4f^Hkc751s&!5%!t3~sN zr}xOKf2x%U;^9$TSg~oME>fZGS=Ou&@~6Z%u@Zk1unl^;#$+>vd2`TAxpM7+M-=>{ z2G=gDo-4Bxn4JWo<`g*q7h(Fre5@gn9J)3F8q7#f7l7iriu4bLY*)Qrd1>N3>PTGUGb<~nj%jL6>{0EyUrBM(MPc{?~d4D}Yuru4q7{$Pi zPP)7S>LDEsQs{YIakC4vO&eSI@h7_uJv>b&8d!&a5G7U}UJ8esM8=*N-5Mr7t^1#3 zN}b`Mcm1%uJ#9o(kH2(Y?v9JmHw}AbM^cF&j(E14Xh9N3sNb%td?cXEa^CsZf{fb1 z^v_4AZS9U2J%vMw_U~S%r;NxBsAEa;#WP53o^L95c;r**2}m!vKAs1A z7cek(h|&0Br^_h@vN%4wRyB*)G{4)@;VSgWgka8itLqv%S43EkIRDCXQld~vkT8i; zm&7~kq6aDwV0qdb@Ah-o*a`z~6J7a{-=`JM;B_@6`F$^CM3!?@^OK~$cAE-_g(p`e z3rm%EG12C7aO!#r-JgzF-e31g6Ny~sP9s*q5|#{pWjq*hsRiN&6II>1U`OL^X{pM} zHG6$`cxdz3mC=PGXcUuvvRDzfMYpOMBV!OnLXxm$L}RI3>a@2>F026?cv{J_Wv(p} z*PqUA2~?b>Cw`bgaa(6P$=HAF*YTHQQi+=#u#?k|(eR{MhpjC{ zQ77Sx?W=#Q{$J>r{rRZ|?5OZoiLXhGgbUe_2xV9+RXNvKvb!k^7fg)O^k|dFA~9na zoisQZ{{yg}7OmF$xxyrgM$x4Z1A~9fX*jjBy!yae0f8it)qG@;#nYS5QzBnB^yBvPX8W(}Tu<$nO@qVQ zmH6}uP2)bn=Tb+4ns$c6XxISUMZf)F9Y_41nL3L6Sg0Z6J}@nzcL%k^yf4YnT{9?) z?xPV^3_}>a{xs{`a+5gSchSG*UPZKvJ?*y65!jfx-cuoX7XhdD96`a@is2j}W1fWZ z#qJ+}kyJI5yOl6aM&P}E7*uA^FnqVl2bu?LagGJ-BK*Gh9j0l7oyHa14F$4eX>~$H zg;_R!Kp?bECQ-a#PbbmY$kuywZ_l*?Hu@xG8d<#D9vEPvq_#MP?CFy5&%8U7V`%Fm zA`BO3#v{h_3GfOb95(6ku`bQq1hijka!+r#PKygb&HEcn=E!5O9Bo|{t`AHi-oQT=h}1>`xR zWYm${QO5%vdn}UY`j)^NQ$q>2786}JabWNJX=q*V{3>3zv)@|WB?564R&b`C&&VXF z-M)KeM35Z(^|wNVP}QZTdyE}xsGCMB)ISCnSuS#3V@!y$DYVVVfvFHpuQ!^>DGj&~ za?=@$ZqtoB3R_pUqqDmawOYzisMR?~zT8f;Y_3V{vX{rw8^uxq~Ci z$}y99)?>waFRIA8ZPS~JJy>*fQ4$b&^(K$uhcjI@b!1O}p8o*^oW{_##u(1NiOo~d z=r`r9NuZKuGjcLmg-3NKWK|I@^=NM#t6+Fzm0`yzcA&Wea+p&NZCL!9jaNp)z z`kA)DjAtja+J69ZET8dfagowZIJr*{yOoouF8e`rPe967giDc#ZRJR_=(F6ac2O{r z#0;5K!)pACZgkHWTWrx}|H4)A&_nvG7N#c3HS)l15V0;)SNoY;1heYZ`_p}XmGK7W$KRqVb{x}oWM!b4Ky29My$U&Sy@nzokl6% zpJ#d9M^&Vn3dg#e?X)pQPe{a}{Yxg4iZnL}&jrgj#9Srp9N1|ue7Pl+=rqc+XmKo# zloA6T8Di3T5(T3@nPg^(liOXnt;A6FVUNwimqt45Z!(`b+!iQ*zuV`Pt^H0a&1!r> zu*O@y2382)k;>HTW8hLX#U8&rWa*xg_;`r^F~-hd(q7Cm-k=!)hUQi-GR!xGN$v2C ztK+9UBM@OIB)95IF1H-%r{TY}KiM#Oy&w64F8G{RRH65TlYBjuZ)u!YbZZsOmbMQu#N%q**KikeEsK~I9cB1LsUVZRR!-jWbQM{~vVv$wl zhnL+xMgx*`SUKS=uu6aI3CF#?-dE0^(WWfH^D>_6f;PnhlV7u`B=-2A`man_Ac;Am zV75D8WP+7mt8q}HLy*`hg2$-RyE!2jD$2~#4OezEkFYVI%W!cZaN2c`gcS*}O}SOE zc-kk@e+BJ%GTJNgYs9vUx{1weB2sW`WUoy0lR;udd_bBc6yp^n?#2b~5BxBv+QMZQ zNXLQ2@1@*$*t64UG_uAjZiIf;-u|1Q{zh*XSX~E} zGN#l_y6Q9TWPT75E!uWG9%0tg0nO_iBrZ5X^(~>I>|204Tfso>bR4tpSd&^{fMbHh z`|gRyC^tX)l&ro=5ggyFmZ=`Q3>}C*rEK+P3NId`G91o${FL=IV9ideC8%svH0F?O-Pxq8`q&!uB)e{(;|A) zB*8yo-gb`1#@2BzQKQdk7>LSGugOP zr(F4quZwJ4y{0G>4MRQZx5}Ajm-ipQT{d$qt?F@K32g0Pxjj&gr0>P0RqFE>t;XJ# zc6Yvt@zS6aY1f5ti+;lC(WYftlVEo0iv1P|+*eNlSMX985MPT{2j z<>S!;sp_%SLSsPX;Go_f=630p*9&;x^n`HNTy4+W8Dhz>cBCLgwwgc3?PRhu;Q?&J48vLMzcz~ue&EU?b ztxUN?9yj{696ea$Ll+fZ%{l_w2?iFY)d}Z{9r&_zJ%1?CJMX8jW>9d|_f6yK>S4#7 z^X3;nhyv4VUCYkj@ZFDkrL+)SdCR$=y?C)4F*F{rYpWV)yijv7vU>SjOo?tBCR6R& zow>zNaph%+h>Ee#?(*pe-9=`K!lHx)F#J&yK)LD}gH%5yw_pN+BV_>zOHuLa z&Dw8PYLht*Potj|iaVI?6WrenX`|Pupu{FMS*sXyv-QGHc<}}?h8HZJUB1inz|!#o zu1j;pEpKUVp~PlG;^eA!k595!07-FzSVA6FZ?>-L0DX1FgXeT(rg+r;+&cEC`*#86XVFiO(_c>N1UCZC@MR=Tr8hI zUaxqoE+EcMx8N>PbRf_4+7(y@x&6L!6$0f zpCY{X5osktFHpCT^PQ=~p5`yFJmH>Z#_SL%d05ij8D#sjR!YCIxL02fPl_1j4vnL2 zUhq%R!NLbuMd_hS>Yox+c=dSg$n0p+AQYcOxsGS+gcRFo_R7lE3ul|{oO$cqKXU3- zvx;&(isI9~(vpzk6t+DovQnm+h+h$zwhif7#yKVQU61(se_EY7>4`i?=V@rTO7{u$ z;w>mUk@Nh5Yab?%uL<}G=eX+xEV@bff%9WS5wJ%D;cK{+GmV69jtUBhQ1v4-?Mt*> z&FjkMRTx)zB%m{R8on{3lb;c6Q$Fc^zq9dI`aLdn$c6TTjvJYk1z6cu=7vVCSFQOg zI8jbyOu$zNzLuCeRqDM|2Jwn-7dlKdD}gTzZ_-?!a_dmaboNOeb$Yn%D_r{e=}(PY z^skC4Qb@KqpOkWnNVGsL;rl$aXJzVJCX^MG!7>POhOy*=uiK|Gc6gJx&(fR4M-A5d zt0QT+bXRwCtss9kF&SW_U_UHWTr7+HO$=cZP%UaI$)CggI@gPHJm?wvOOc79nb{eV z6^6*O&3%T#5k4F?8uTP&$2pc%AT=GlNIcX{b4QJ%lh_O6Ag=9}cMuEFFip6&2%GQ#5qrW0_(+BjOlgP zKoIs74`&7PM>D|f-p+RmhF@<#3P)H=?(`d|{P0B7CFf?bNcX2A9*+bf8`QntsH`rOJz_)qIGVIA&94u~ud_$UoYr(j&+iTiK1Bm!w$jb$%y_|4wP-;?|D>N0aE z$@1fh*uEY|k$Pfh4+~@(>bHmhE5^WQeIJF-7*WMvs2s`F3mS~UTAfr)sTq6AZFPD) zj!&GN*3#^PE@HD2l7lpHL7jm?|8R{v!#Fr^! zYIEX+%YN+(ucjOcaKcPEMLyDV5jW({{p154iSOXT*7m^+iH&JdY!CmHXC4%m2PWWHR!&h0Pc(`_rS~%59}T{`yI%82NLre z^e|=YhlPx>(DJg)bN-h&H;vRCl*E9ottu8plXcTxWj29mxb$sh{vJMvzt@u$^dsH~ zv`60Ku?n5DV@RbYn3a3dpQ({9G<_V(EBs}L#&&e)2&s>lQRIwg%x>3|#V_reb55ZN zu}ZE|%`^Rz4j2Zns&(f(;imuo_XYFct74v!K9my~N0Y+k6Iw zezO>>DukRf=>*A=E;(^;i!b{GfNp;U^HLfNJU&Es6rUKPivye_3^+%DQm~|*mD(?n zsb{~<*Y;h0IC(I!n&WTT>R480c7edSM6aBvEGO%RM;FvNhJuca}(Zvak`Rd&e|?(7z(e)fLM}^mP+gK=q+YLIv&L{{zmGDX9TQp)w*=+m0U zU-5zvjxOWMs*a11=Shrgd$C%p%^-t_PM-=1=PfXjN3F`&R^8kDtMoj|tKQjYOzl7< z&fMAuS$_6~fThp;MPlF`wUh{Q5LnX6t3;tC`4C?lLk+g>6~_*Ecrw;;XBJuF=?{J0VnLOv}~0IqYVNi4jZc^w~nf6`tJ0$oEyz!7G7d%|WMNmNZNph(iJ`FY)#Y>C8>n5hdK=3qEdLeU2V z)e=f7;cZ}P=f^Z;e91SKGNH*pFf{eb(a!OFlD_7g_hYWK70f1|P%u6is<+Pn$v;{= z*E@+3TP@C&NSrGCv`bKPLIkZkIt>5T`9(W2YVJI8@v=h3TykJ-w|KvZ?5X&G2>~d` zjb`9iq)q{H+b#T((bOv`U|#(>qY*c&(+eFS&QETq@t|tqS>le_EXAuM zYM~HO&MSlN%7b$)?d5;H6A?QZ=gH3}$8*B!EuM1Y?n%d(TCUsFGgqbXX{uv)u2ws8 zt-M-_OUHTJ$YY`V@LeT$n-dRGZ*o{CoZxXsaU>Y5ZhYm~6s`x0a-LJ%y33sSO>-MI zjZIng`TSk$a~-<2#$L>ED6hcQP}GWBma$h+^VnXGl70DPwvUek#L4DSCmNU?&eAr<@rm=^QA|11YN0+RdUERh#=uGs z3jDHKIA#Nfw-r28;P;kide%EuUWnfz001Itf&O0(-@#a6~*h?m2n`I8OXO6FFSH3WWpkBmT6{ufBW8tIpth1k}FtK=V%@>LFR8> z83gWJ0PmASpPUsN-CJX#v?Y|Pb8rZc%fGlfY@55A6*i%yhjg<@elhbh>!Q04TcTG2 z<`2!Fzt^I9epa);Fxd*A_npP8_Eg47yd&LDQj1=||8j|W4--$hI|pP5k|BGq!=`G%4YJ8No*hn3#L7dRK#bZv;DHXi` zYC{j6vpJz17BABHUIa7aI*kF+PQN|vZPlBB3*FY;UQog^xXZZWGoFqRtGDkCR#VuU zdkbdV$FzUMPh?rW@~FQ9iNa=+8r|8Hf2wf(&RO!WE<8a|tb!junC@{}Tnw>i*mF)w z7sE6Vs%CVvI|Cb@^4uMP+x_TEx|romAG0)OCPp0XIgj%omXnhv=Lh>BL-^C@!gg(L zm=n=CUY9zZZ*}ap+Z_jPdD0@IoNkkwEm{UQIGmCWH_1j#-3o`7=;0o6VakC8S<9Ev zSstJGdGM^Si$|C3`m+$W(($q8_YiZ%8snp$w~RjJwSbMNo(kHpWo&}BbW0|+WORZb zW)@&3<76oU%*1)v3Y86sQ3mEHOs9HT!!$kEu>x_%WotM42LPa_UU-@+{;|j3LLA04 zE$f|En-S|IhQl&=LIachK)X++zOASVWu>P~rSfK|5mifKjwvhyh^NAWq#I`TI~3yA*yu#uHT>H52%t3)1697k_A#B zJ3WrrVcEououbPcrtdr|J{m3U<{RS)rI3TgxuxHMZz~!Af+dM#@M=hNKf7 zgw(R>>L+xx<*1diD!<+ng-;9xADM(&CSD!u8}m#& z{eB_}k|#ech}YP@$oYY1e++pZ28Y7uc!h5&TPj|=xa!|*{bAnAQ%lXJ53;>47Z96; zBlV+rbae`iFC$|fpL{^z8HM`{{qBCNtu$SVvqU~o+}u>{K~ySj>LMhGTLTK=Gl^QS z-YeAGfZs>cHu@ckg4jtZ=y{%Bwxq4XGK>~ag5x*S#4&4+ojTpP+cfK58$#LhJJ(!Z zKACC-wTODMY{JpJu1wHeA&PcqU4PVu6RiC)#}?ad|2fxWviB%Zkj?88kQ$qC4>wr! zXrb*w)nUO{h002o>WsO{dOW`kq5eTa1XX>;FUlnqs!MZNf;nD{nN&>hSD>U*je#Mh z#;$?`yQ|ID3^Q#&)uaNG94$-&rOy|Ao(~pjX#};qv{-d)DM5^QIk@7OUt`xN$_cmd zX1S_IoXIBkQi0&)9L=g;mweOVoOG(|rs0;Ri~^X@`1LKOaeE9Se<)YUR}DlxrZ8;}ujkjyu>v8}sT1|)E4kgrXTSu^OI>eMRP zi4<;^7xxROjuhRssa}YH>MJ{gI6=E&gs~3B&eITA4($({$sP^4LC9GKRupSqH0Ky3 z{sU+KzPY_7*o$_z@32q4Zx-ROc&nvyS+z2{bJ8{MxZ0HGg&$k-Jay*%tS`qR8Bl>{KCeOzu8Wo*Kt%;9CzSf(^{q_L1N-1 zY{{&A2rRx~5P{POLif1#bwDUCA_q?tCN1450f7(GODl10!V;~Z?`{Z2KUdpGi0{GO zdU|SLFjDp^Hf7Gt;}0NkBm8*?g=^W+x1VT6Iww1}-u~RCE}%1&XlyKo`barn=wp zJ$u~GysNLaR@=sy#E^U_^7Q9B!d$OpE`~3Fq8vQjCwAEYk6!MZ_>?5{dWxbI*s~-% z$2Gnd1?Dum)kg#PS$L2?_z&l_@|&teSNYxHFp2k9YgcUjP9JhV!B%N~zG{u8u+jAU zkWibLgFsgNqVeX~CW(qfahp4jQICoXN>jc+jYmuSUe$=QHt?phL)+f$&H7iKX%rc> zEIBm^){*h|p!eGf`|=7KqxcqM`jAPzw}RO2(EmD8DM?F4={r${q^!|DUVh&0Qef*d zJY;r@?zNr7fck!reAf}L+Ru5N%r~5iFN6Sm=EpJn9EQZx&GRKueCBtM2}CIzQPqLzsyzd{+(`UAz{j)vDIJa*X}Xq z)mPX^9(pd<1VB)$m=Cu$#3j9CjE!2W^koXw3uo^49TU4;Db5|g+As0Pk4D-N7UHEA zz7QhAYunb#yj*YvUT81$r<5wxG6Sq=f>JxL8u=3Q86MmVfht>El4^1NQxE^A*pL8Q)Fk_6@OjFH8z{c|KY7 zkk_G~5S)~#Jl1-xFkFKJn`~JxnR1TJ#B@yeORUATojUoaLCJ7CHhz_4ae288Qtxs4 zj@ zO@QD=bCN|J?w<0k3RjgJhvAb&PhJ$+71&pntEWM(tj}&g&^{>fbQcM1BfdOs|G3 zn{SmlanVR9Ba(hATadpa->Mxqoy1I>UpG8v>t`KCMtqg^X4DX@5s6mO2wlf~msl^- zvt1SGJEGZ}1|ok4Ume!?5ybXplM}PpaHem$FQrn0!|LDL7aG&Tn>08fQOc^Skb;k@ zs%o?650-MH9kNcSQCIG36vIW*lUqE^I@xL^KZOdg*D&4B%%`Q5+o2OwLTf`1U4Prm z`wlC0O?wCTR*!Xjj1BTlZLduDlxBSyayh3S-p!JYQckRdt%G&rJtRjj;)8i?3!}V# zJ3%)oi)Zn#vm!YUQ@0YyQe98fl-S3s@ROwE+ElQ-#`^Y#;PYPwvEq@h>!!D4751vb#UilGQaZp zOk?x}f0zFO6p)xvJLfy7YE{3+C)vSQmkLIp2z26clD77~jPL`R>8x)qB%me$xZ^epz4c{<5h%uw_84INJi zxF#1ImDH94{VadxWa?vNpi5l&ZlXDYMrP_H6yVJr-j6+BZm`NfWMrk3Tm9ND!YofH zt~0bnE-f64+@It2qK2DqdG+^3nO&$z{^0PbMQ*JJNRh}H#1GcLN3FMfJkv}bs@{}~ zGaJU4p^j+5kj2Q}$b{bdlAA^5B-9b#r9xr-W_k~cOzq8>7|eC4m=f*w#fp-_$k`i- z&nn&-(G?mb<`^kj$(447GJZ@l*4Z&hV5P)NhunM=^ZViC-_vAaz`R=d2lJNI{)i6U zF|KN(pAzXx_6x0u&kX)?n1XuCi-oJ`C7RmpunwCC`DlBIv$$b?BVWeidhGs^Y`+c7 zM$s}@BEzxB(GSf_f^p9uNnyCDTyayJZHMmIlv1whz|#!U8~{ zeGKwk9Wz1ItwDp9RB<+p@i5-5A`BOLf&;s5}GM-pa??nrEXOGmIDNW^(9GrdrvqyPm_=&ytZpvcf^BF zv1wO9BGxibrQrlEcERjhhfP3IE_BS{-G&g8)>)*kLOp-hY3}>f^8ApgrLvG8&;9^V ze+dCrvoGE@Ud89&zw-Y9d~tVp)Yg*5Hc5`zzR;X+J-ax^KUq6Weimso4F>ERi6RTa6tym*pDtcCyPK zf_N-meU4NQjUB|>+uN|KRZt?-qUdF8!LZ0d(0gHPB`qKs$2l1<$HOmo;(~#Je6{2= zrWNCAPM~dFxJjShVc6vhuJwS3hDeDXArY4dX3=6{{w2e&R?BIxkxVuayLqs5I2!Rv zH;L_fc}3K1G_ZRVTTHMJYh?BNAuwF-0+i^ob>QM;dq;^9Vov35yZ`3u<2yxY1vp0( zs<9?0^&qQhsNiKl&$p%pRj-eC7ghRyE$WL1Ioq$hw^((L*HU6|^9WTi2*iat`F%ac zqcJ{OPJ7Qa>ny3qHh=0ISB%5cR~b)oWwD*Ha}p;-P%yf*F(A%@ zMB+eU($KrIx!2CbE9u){aSx9P0n(0cevsQ;`8o28LY$8&@bRR$OdAtCdF7Xq=_`{& zOc{xzi9X~i*!6J&Tw(QEV0rkuD|MZqWRJI`&L+pSP5pf_jfitcEeQDYNT{eJHJeoD_ccfL~=doz|qIwz(6C6t{4_=q0|#wwxilbY|BcqTP8 z;usk{h^neeUNg7LiLLHw8imti+_xvkD73Rh&`lGb`h`9Q6^zw+NpOATSp!AE4Ys=25A>|OwPNGSWTe3!9v6cSZ9>*e%8e5o0}rxTYP%j z?dbmZx^6GBJsKtT6SOA$qz2Ud-v~*G?OI%YD|f^53*h<#g$J3AqD+L#aJxPtRQd7IosP!e$G7kT~koMH%Mrqd?=?<+h>okHbz^w60eat^9Nu? z0Y4t?2d{GDo9`Lr8>>Ed^VlSuJRF`O^vDKWMG8>=e#ycy@*{+YOhL*&N%t4(veseI zL8eal*U=xq6S9BAuv6YWS2byh<5~?bouIKhrSC^64%&ZPaUO70beVSl_Fn#g=FLjN zt%Xe0qc8z^VVK9o;43?cxbu~QGhDVke*S1@JuvKq($~k+oao==)wVOdA12ly?Z|O( z7W~Fo2$WY|@uYZ=C^Ty6h)p~{{iI97R)rosIZH_L3UXE+3Z#d_!ID%UGA?Ch^U`ej zuC}0C7OkY2(ZO4iqseyk-!32~7f-RC{9kkH;aL$g+*5N5-tObFeNJ-Gn5v~yAqoS=QQ!tJ!Zde{xwi7)laod>$_Is z>XvD`Ef$nFrr8u8pK7Ff{;R}o|2r{;bi_rF`q5~|hhfJP(5|?_4)cLG9_BF4;{(J6 zOwDkfIN6ILSXt2WUmZx=1+duL-y{=peYFd?m1#vkiI{lhP-&9axuY4MH}@#USw1`r z&q}2w6}V& zsT|l4q^!WmUb9iH3 zD)Y%_SK9Wk%6Gsbk4#8nIEwwy4`!da3-mdA5ZT&IgMKj@Ak{7&eP4aYO9CtNuv9CD zrZWK%Sv=xcKCkm0%y4U^hY0eWb0QAm%3a)ON6Dk8HBkA-^IT8aC%~#JLmHthb!vWT zQosxK`m-+(7FDBp^q}3`582P)$Zd29%b}d}hd!j2Cv)^PJdiN}KHO9LlMBxgx8sip zUG0iKZ}s)v#_G)w7SCj6z0QeVmm!}?7XmroOj&GwFer^{eZD%`vjneOshqHFm#N?D zkuUgWyvXlxEKWt^+`R~~bSQnZCBAzxQ??dt{BGE7p@*aI?caZ_5#J!4Udntpvz)Mz zrLFhs>tn>!vwrqd=y!TAJktH)0rk@OfGdq!DhB=uxT_kXb4#rG6RhGI;bGY7;kEGR ziUs;+dMrM((@}ItVuJw}*K}P0qa8M@M05h$I0O;ZO^Y2I^uO!Fk3Ie25moog-v;`% zvdwIQ({9mh$kB(SKM&=ZaEYIE&QfS&M_xJUl`T4zRnEx2Q2aHsrAJOSRX*q0c<-5F z?&wJ}nmHpQGmYo3XGgwCx$%aUrK*ZZt-r!9v0I}&0^Xsi#;!Rc95T9u12PXwXG3^Gc*H7#g&p(YW#<}00g@;|8w zlb0z&uYunte5vsPYL!!!UtP#P*nv!A)C#9$$fj44h-Qp?i-S#Xu1-zVR@*SKGA?H_ zU&K~e$`h_JYEvsE8cVe^DQ9K%dDKByE;BH|&c`!|l~5yIg%Dt0XL_-4$Bp_QBik~d z@ge@9--Rz+L)=FQ2_ZttF;W{6Q5w2^?khM)QD{Pq$@AEVYOXeI!@Z-GC z1x1SLbzJ!kEwUgE;cObw9yT4Tqc4s+IWrLHOr!~al9hIEoL%bjUh!5n>U&Bc{1LM zU9Kd7*o(-RVFmXF`t(15$pgz%3F-@j&ur@dTp5)L=8K800S1wdE#KXH$Km(N4!#TI z(dkhOsnH!phqXLRm|l0>Uu5&3jf`zRIb^O_)xokD*~oBp&2qG`IONYK9z0xgN_W(Q zFLjYi`aJeG`Q}y--s_83if06xcjo*KTw1Pk$%UoIe*L`(8K;yi8Q2rJ!}-m_6g4xk zmeZRUrHvV$G74`AD?LZTA8d}XxgTGSNEj6|@R=G^F3%Z#6sHYxQxW-*lptqrLU3BF zmuQ@V=pD?)tX;2Ri%j+wa#br-Emp}%)2r;4@nQX};Mo7K(d&pncUA>_j9jn zUDtIdM#%!@3UbMD^M+d6vCLI=lexU~b*kG!o%9x50?#J))+EIWKwycy#c`;W>w!;-xm<0iFOb#fR$umfP-whn6cSgnuoPu zYdptpVk$Uu4N9&(6OnNW zVPXLuva;5|`@k*s9t2<4oP3gTe8H0H-iH|iGqLm|!h@rF-?F$L`5%;@Eq&Ij$uOKg z{=n1kc-sh8;=c%G+YO0JvDIajj$^GlnnR@f_*txJwNRfyo`mhc{QNsPI5>V<7!Kw- z2KY-Arp{MU0l$Wa9Uz{A7Oup9!QbIXAy+rLBRlQF&3zKw){}9ob{$9Ui$HqyTRR7% z`y^CEN>`=hxECaes6V5r?_RwS7rk`78~TG0b!@-7xqW!c*Kp(nQ;X6*F<_q*k5v+i zpY;&Upv?;H_Fi!A6-@vw+=h`pw##@*IeI^-D=bE`cKJ9iqhu{=PfOhL@W%K229NpI z|CrYQyEjv%%Q`8ks)ITkShG=~Q17p8!-i{jT9q3;$7g)W#5pC>h6_oI$RoCwOM!oJ zu&{OkY^7w(B;IJahnu&TohQ|z>FK(bS6o~w?2{vnLUJ;5Y<2GAfqvZV_0c^u`>uZd zq@kUiB4VC3-xEFqwqfn{0V7@k#kGBmx(y%!DMz!Ob#dEmelN_l`pV5VPkAfU3(y79713J_pX4M7D4(YlKa*ix@P2c+~Q}e>` zvom7WYaLKKMHE6L#Tf=jgI#LD;Qs=?Wp^o=?Z z4dj8YF8Ut1JnDbs^U|Nnu((O z8IxYUwgfRoP!X$S$j>`u7mRg;$F(2t1I~|pE%@Zep*PwbSqwe6?AMiRtQY;8429}F z%7U{=tj|E>E3yW3Pus8F6Zb?-u-mD*k($-SJMHXFu{Bk%oxW>h{+Z1%Y<&oo;oiKB zMJlKN#rfFwX^{F9E1W$Q$EY*>#i?X|2Hc6gcF*rhD@I%8^rxvu?s>5=wa{GB`Y8%g z2S+dmq)7YoC;aYHf~u?o$^RKP_U__C^N2=3^g{tT!Klp)r1<0I@W<~BC4JX$`^B4> z{^HDENKXBUFg|uXD7%fceTaFHetGGsYu2WaHv*`o{V`%q&m|9^;LJ(+Z~@9^XFT8< zKk~|mDyV77s@YJSRwAe^t@P4OOSENPk#i?Z@hYf41W+80Mv7Yz6n(?Nw_U{Kt=)*! zDJoBd=T)|@29f=A?dYi;qd8)FpYRI#S<;Wl_2%3`gAFqu@VV-a>9tp=R>lgG~|Q zNXma0Ndm8%opgB5`jIRm^^EuGjIUU?pH3za?(`OV#MG5on|3#QvJ7v{Ey_c!;2bj@ z)(*4K6LW#XJg;GgsemFR1IT^(8ozHoe%E8ij)l-(rnnl9dZJ6aiJ|s!9yH0QFS?VB zH-eO)kV|mQTER70Qjb9up=RKGFe}reC9lWSGVmN|jPz(rsY2F#K%GrJ*NC z@@hpe@l`&$ho7A5p)fw({ql&JvR9`yZ92TdgzZ#0%~eVfBnOgxzyE`(0Qv)}KK8m& z&$Ax%6@>9wqlt<2t zbJb(OM-|VNYg~NZ4 zK54P`M=a{8OuSr-7I>fKfq5H8V~FxZ8bhfGla;BAO-tPI!L1_c~>lPH+EjgPAcl;GjO+cS5L0sz_3y!eItpS4J047Fgo*C>=ynP zC;VRp-v@Iu-}|k%hMjrh*CgP-I7e7~BveLVLjXv1HV(!#|B$wKRLw!1zeCIdZ_B-X zGK)v6k#pwW3&*`Nu2YeEzZ%Djodzn^f%M>zKSRM7rm0kArVTuBvW9U8&i*m3Z+O_l z^DlpKZfmi?e9B%MJr*{6?|&xdS>e{4lPb)v{u}a`~1|$X}D<+?-w0(eoJH>p~xtGoq2J(A1q!a zo6%7k4CGa2X@65W@l?lF`;K^*<&w$yFOG(T+T_IOt!PczZGrgdJ433{%EW-|#yB=; zutnhe+d67&L7?)3nJO@sRc<*L8=&?V+Kn*=+H>|ny6{i^Fa`1xO~Hvk184-NDEoi` zI9Z{M7=&ZMIxcnl$TjUR4k}gqFV45+5sB|w@EDB z+Mb$*Ro&WyDKts0d>`P_CcP>N#ttnOl=>1G?Ap)8A8##4(XIH`muBrQ=8zmPRS?pf zG*mI)1*wgjujNA&vs++bjA{a|!Sy41M&GIVzc`5*HvsX9Yil51wZH%{37hS{u^JlNbE4|*twRMo&>$H56Dp1%itg;mxHgb9->K_}Z z`60IPzvl8&Y&;%uVa*2r*W=M$iPe8ty{fKbsu7KB ztb!I&=GIPCPhLn?o01T-Xn!DfJNMEUh3N;57S!qR3)F@{_+pL>p3R@ZceGJrlKfkA zCWYXV9gMgD>EtjXMc+<7$e3RaBy_V+@{*`}Ab1|yl}p^}UtBv`*V60a@`p9KjeTKo z=me~-U3ZE!WA@9Xm8P$#3<>;+twu zl-R4a8AH9+1$O~uy&1{KCJoY%np4YHn^c<5i%rUX z-X|ZXlkmI)?xiPH^{KBJQIFjyey=tm%?LeY82I_+Vt8i{$s~BURghNUog*v{zxb1` zNSrQ#P=O&;k9^i;|7lE1+qti8JFI#tV`riCdED@=xvLj(yf3{;a-m(2hIHH|QSItP zDdyQ=hPbNJlzOstR0kX##m8t1N-P60M4iRW8$p$seb_&P-@GeT&wn=R_{M6J#;@FyyhRH*U z@;>xuJdrl2H&(lO==^o)@5!XVy`)b20nAo0oF&g3VJ_PSDM0qKAb@L4{-Bo+EA-;Q ze{rNNY|?+#K-nXxGsFZtFr~4Efi*d)uA;u>GF!;&sNTJ0DbE#&sg0HSqaHpV=oV>l z%U~Ew%9mS`r6;1y@^Opp#BI?1Pn`UZu4Ew$Y6{`8Sf7J^bBlfopC?9g|Dg-M81Q*- zXn()-87b@N(VyDY^S;txIK=QPudqT$4?L(GxmWVihX=r-r*?M-44@EE1Hp$Q& zw@!a$lHZ{$4UCvDyboo$20PEUCp_KZ!I1&Xmchh#Y(lq#K*f;6+< zgdfH}DdZ-`lbofmRyH}S(XW(%UqQ=6Yx;fS@74E_Jvt`T$@=ieqc%|Fasw&o%lpO6 zm0IGSOdHx%__#X%i#Byu_MhEtndAKpe2-(Fk6Cz(F*7u(srFY&+ge-nv-@k$oia>% zNu{hvYUIM{3_@73Y?w2@nc(0sBoLUl930#Ld{6DaMmH(XI!8#6TRZ^J<43(YN6&&| zgibIIO2*8F3rFU1ceH}yz6GmGPu0nujT5N3N|dcHbng0SH^^FVoz$2VpbfqrUs9Yf zj8D~}T8j?bR92>T$>?HMUbcrQ_K3u&v`KwZ`eh>m_FOkK6|UDy;+pjE$^dQkrHFUr zW;&H~hMMYSGV`t#sZ><+bcU#r`Lc8U`$w(J{`v9pH^?_xx7E~+oXthimG1Dd#=?c2 zF#Z^Fv4ZdgIRd$MiYov`2oHhQdB~XMd69*1nmSEvYqJ|=zY|arVFO~{Nsk$jZwt7g zql?0M+j10X9+J68R-}bmwL6&UT|p2?XQ+&scQna2FlwK-nQldj1*Dt)G!9zRGq*Za z9cJ&$5v&JK(8zz6peJ-1bFEZg_dzzzG=v+iF1ap96+k0U?2>~G|1k%W_UVx|!vE&` z>0hc#&r*VeDmZ0->B!Fodxi3AA-3P~=$(SZ1Y&XAIdpaLUz|ek2xjAN_o4W|`oggz_`gmOPh_#b9^M?2P>)K`Z3VPk!I}co zyhGOO??g45E9$TIN%HtO3Yipwr5Ez%9Bcphp7+JVPYhy+Jj3W*)tm5!6=&l>O}qUy z%Ah(zQ*^(RSzbw9%)fLn=R>|AVgz5)Om1y9BZi6S#OqsdP>)oDS*5LXwP|4ISX}tV zbjdnVqFarUY@DKL{5UluR4yy>|5GB|56L8+Sc*ql z7}A^qiU-q4cRY!X*N~KAdu>fqb+&EZdIOGod^4XS3^Mn>L?4R=PQ(tHtTDaq&j6H) z%laK-m@|(2%>>f~#?J#Sh*#T$2RV8z^iYI+r^OKA1dyoi0+u-m`*# z?4h(5Nta22n5GwyRH&^!b8d_ z#HNoD5x9@ukm5;X0{n4o;GE^FlJYIii8=0%`Fecy=EYp&CsW=kOzw~1h?L!kqVdn# za|TEA?h~jN#(v_|L>Y^{^$lz&=Fr9oy2ejZ;D?D`yEY7|@*Q`hl!&2l#JV%e^Faz4c9kDIjqBLv8^)o&_~>MY5k>~3DM?k`TcpYr`L(d?s%*lX1o z@;@+{2gf_nT^)<*>6`HCwW7Z`$#aTWB$WP*KZ{{TaHqqq*n;YI!*~CipphXeP4;g%7d)f_BaOtAVzV z(Q~X^JFLF-K`gwxfT4BN4^7BJ@_u1qqU{YPZ#i+7&U%_kJ;%DWjL*YLn@zG|3v1NC z{m-whB^#g;=r-6oFFLn=EfkRX94ko6eRqfD>?b9Cn{d-R=&GyZjIkm7*bwr=T~QPB zHUKrM+~UU1)d?WFWI+O>rDe93W;8{{>7ZMC2`VkEW{_~y`K4ayx;L|Yog?idTK2*3bOBmDCHN?QlgMl(xFZf}aSULS= z2J=57VjLY=J0t0)I-&p)N1o(B_BRrMXfQmtK8pdnh2d5{7JDW9CE8}^op*D}vpQF2IEBzo8`;VP z6CtI$z+^{XKeM#uXx+q9{X6PJgRIGL8!$+0~;sSoEt5zPH0C+5D# z+<1#;+M}xAtWojXy3pcn&M=(%Ij@0Ud@Omvf?JwhQ@m^jn{DLFr5UZiIE}txK#wms zr^qp@wW}W-|GpV4-l2xA5d|?jAMap{{Z8h*<|T^LOodWxmv1AaeS9)u?+FCXV7oep zs%iCOmO`&x_|tjxP_x8yffH}onM5c8AB=#CKT}pd7iBOTP!RNEr}qDDYcvMXNcDIg1&yNOvfeLzv~pgAFj|kXJ~D1tw&~% zi#5YEw>W2H|y3>N~ldkuh5#h$P#o6Y4 z2>5ilLQcP^gMpI|p5Gc~8J7p^=o`~RwR0G|dNuG?59`X6Arm(^U^)eI4NZ7KlM)|I zO5=!t?Qf-;FlT{?t+_NL`I@xm#q}aH=8_Uswza}`F+xDbARvj4q4_&`*0?fg?6}i& zymG61R#x3+1TT|A8|zlPzK1!MR=FNt&wJa~OX*A4y`jtwe(i#p0idlDYu&vfo(|wA zU{!hjIANBOBMCRhx$Z{3O~Z>^-HV?m12oNv_bQ(~rQD7G!6w~ReKWgsh_&|%&YB1^J2BENVk95Y zk=K#rFJ1<)gS51<Ks5J1o z;UM^u(r0wB@grTJIRYqd%JtEf5tX1~*k~6fb6Db1>dl{WG3hi?^d#;5Ds+H4G?%K} zw74(jkO2+>>IZYZ@{!S;HJWotK*82)UV>&X7I+hmn}fLOQs~1yH}Y#ck+f2M{Pj-i z{j~!Iz_oNWxn4%l12aL_)?Ds{zc!y1X18B3FnRNISxw6EJXKc5`y~QBsD$NhiE`Z0 z%Nnbb;hve!F^3-9!i3oopZGL3eA7OCor1dA(^3&ADuyR2Go7KP44vy2m@*uT5!na# zYL5&Re6i;f1)deU?jEY?)q4>FHRUc=7&t_NRbG7jHfIR!^Y!?gT8icBt8lEEq>m!( z$FA7;@Y{(6tMg%|&lH}3rd&Yq=*89TgWvMcN8l%JRX&+iw8`OHLTCahw11Jy3)9`j zJUw^~RfxVVw~|>#ac%HmuejytZ~3jvr7m#Ej1S7R@p(gY?KFQ6FU4=i0U3u}p1TfOYscs1*NkSBC0#zdeI-YLUBP1KS-$tl+3wCz-wUu)_12q^ZLS`Hob z8{&uyHG!O~z|5s{u1OhOW;JQTOjf0_}R=A@gS;h=L&oU=$N; zcn%VNUy7MX0LU#r?&c>at(!bPuwkkZZq*Z68!#W5ccJS6ClMXr$_Jq-&v%RyjVzl; zf7=xS+^@+6W;*DS_v2XYIY(Bxv~qLVfcH64CLjNKW2on%T$V1_lO<@yu+ZWPXN)NhN#s2jUJe3e0fcN6pASs}J|&tSR4Q~QYP>FSB1QFe+>=Dun@SoaGK>kw*Na}pN#KxV96%c&j}tv#E!38>bRI@ha^>Sf|3 zx!1kvI!?a7zdCI{`LI@bx3{W{PCZm`!0-HL+Kn%?w7no+HN*c=;nn1ifJ>4**T5|M zdgR@|_kLY(A5pdz&4o>7P@&;G^|OY|*@Hv5D=vG3dpuvq(vv4|%ZR|&Wm;zMSW@hf zs>E^2*Nh-s9yi?z+7{29W?88Q8nUTyv=hh}ZJ29EyBW&k+RIIb$36BlIQae!&+OY{ zc@g^^pifZPOnzI`*D^)bEM7sutYWmEUt)-W+3lguK)+Ji`=8#X{UYyjqql}5C zyc2Dd{{3M>c@)7@U(tt}dH0)r!C!z)3pJrO7A5)vbq?q}Tp8 z#|*Tf3o|@Qi=fZMr-}E_TDt+MN(u>90o&vt5d&3{AFcRzZx*(33a4G?;a5SITZFaw zj>+IsAvlv{4vi+f#O_~dzyLJFo0V@?Q0|Y?yQpo0B=T1@VeOBy(dlnos%4{S=LWV>7#8zD1s4^vXxY&hG$Mv|#omStLL7^dD zy@)<#q>acnq1MIM>_7TD~2rbq=t! z6H66XRxPu9mjU=7`QOw@JR`v@XOiRz@Sx zNZ;1mPjW$m;%!tq{?*$zvy)T>n8dUD-*#cV{q~<^2ev+D7)EM~aI6m9VZG)`8|;!U za;EAG?aL!9`X6Xz)aT*~Jdt|@&E~D}C3AP$tc~X|Y9k_^o2{J8hjBjS2oo%)U~oUV zVe7_IykPVGe6R3f*RZHmGdz?^6afN;n4BYE&;r%oDyP=M$sQj=+wi+uBAVXmi|D!V zE42^?*S|O~>tqycVIJDh;^f}Qp0pAn83PLCj>^!9oK-DV>d5LQWx*D5#qLaY$=Eg4 zTQ+U~EkFFV?S6EW(v8Qq7-byzv2a3cXK`cSCNppLaDYuBrh+wuRbT#|twExw0s?k@ zsJQ>|6&EWX{>t#*gz|r0epFhDqGIuH^Uky(Tl$kg+i3JrHNPwQ&ZVEhqae~#vcl+&`Ba} z4sy-9P3`M@xw48| z)$<)0__SZXJKj~W!PAW8kFojrK7Gnm%6{CU`syh5#EabT<{rnZnVhg-{7#&CG^|-Z&D)qOmbyU`Ka)-gt(t)= z*SmQ8nVlEwlZO?MubWobhvJ!NH+@|s#0&8n(;z<;7`weHcq0~0g z`a;MoZo`#8OLelcZga>To=HGYP%?h97r~g`Gi5GYxfS`$D(`SrW~W5Gc#Q9BjqK!n zQi6h61gqe#*4RRm;GSYsqD*G3;_ECvWGl4pz~Nu}rs7ekg1~X%^>NF|Uz`Rn!&`CV zrYiJ4lA;xZfkuicHwSRfErh|_?Qd1K)c6i^76x_nYaqHS-HikL?9V5$Wh)>9D|Pm% z&2h+ib8u!qh1st)>Xq2W1IqI(v6b3UPU6hG^>(l)v)*YN>BXm}1U~q4C*G z`&F0>eq$Q9_YQ0>1Zj|8wCK)Ux+Fh(7S!){y4Cl*mLg=0toH8+2G8mT77k?3i;R44*M&y#Wt(sw_4a9Y-H)7Rr*saX_y8r^4k<88r zu87d<=R)An!OGC}B4oXp2*(e!T8UqTda1#UEGkBR8Q@oJ=5^8q z$V0#08^O+@2raN`!eAczlPU(TRSC!WzFH~s6vwQ&z9LZI0M9%h^|)QJhYNl)8ka90p#3HiHj@DqK88Z4n;(Q~>8umH zWtvZDpzAQ?y?Mfti4(EBqzYo?9YhUr8v>A)x$F8Uo;j5~99krVhRb!TEfp4Ltpga zQY`SI?2x(%{u1N`MheaMtMdm%yBcD~^HWG01q!Pay7ygg4eIFe4xy;%y5M-o0}-8N zyY&i%-nsZQbaPBzKR>CW#T#=JlC&gWHL2kC)5(v!T)rhv=v z=@r{0DbJ9(lJo>~*B#9fWM5`y!HV!-9On@!n}5+Xd;+`ZN4@pq&yEvYm|;<6siFw!(Kgo~@A`|aoC&n6R>E~MU0b?5Rm8e*zo=M{@aPK@l1+Q$-}Vhl28T8 zh)$qx#vDjd7G`;Q@(;-r-SZ3FhrwC+EI+RsBVp*5LHvXh4J5WxMs)q41#a(vcY+>& zOD`J(h_wC0Ojitaj)$_8D(+SNa-lR+qpWhadC7czO`x6TV$Qx46S^QcPIaB~t`MR= z&=-tGo?f|EDY#rH=L$O}bEFL%?TTy>%YqHV)!e0Jy)vFI7;UZN64Z3MikUVR{`rVr z6#iXe^jndoFFn9G2Arrl6<1IiTe~t}f8AY;*vhN#IZ0iqT46cpL)wFB5tQoVVG79# z@|p>h6){>>bo@hWJx+Wn0h!yUROcx!@v%tji=l^BFRcChV)&;wFtR9R(w^Pzl;y$7 zmM9rf_H=0E+U5E+bu;!Mj?3i34Ih{`y#Ys9{p&g$yC~JGX`YOBtFbR^?UABJ`aG9K zehtg$jr6)s*3v|3m?R$5s{1m%-$603!GGksUR2S%fV{54sFAd@T6433-7iTw&sC>r zpu5f!H=U68J&XNG9aTltmapM28&5xvG51tL_qck}rKAz=#wM+28OLE6j~MCvFRfkQ zJR}%={yf#-kUOJeQ6FaGl1^=4pO`)mVSys}SYI##E}hq#pZA9o?F`n;5#%1&D^<@0 zvYDiYJJYt*n{Y5i4mMHBE6?Vjk)So!NKx6U1GSc6faR}|W!mQ27aF=e8eYr_Zi%Y* zU8gp7qc^YF%lA%yyL}N(XbAD>uHDC@pSVd$p4-fsT;syncAu{$TrH(H)eZ2cHK}BhQnjJMm)ff`H@&Lirp`pHD#E+z;q!=twnvG^>t-7Rh zzp7AR(25`Vs%`wXXxOH-G)2a#({Y_LIlsXdZJf6{K8j!P7Y7TkY9p8RCkBRlqN}*Z z)#Qg-*KV7i0d|=o5er@)Yl_VXnD;1t8>;;5Ppk|_dK2Vnh=iV(*w{{5R3mb7#ky8b zN2wD4EMKr3R&u|$*n3J@w%a-G{Bl>S&X$Iw*%cqpW(Th|zdVVeUsg-SELTNk9$vlM zBIzD~E5yL{^+f8OTS=pxGAzY?*yWPXz1FINEcra5iz-<4Xp96Z$^4SAa0O#EJPy){2e%hQhV;BX!KqX>AiZ0bSdv|&cE zhQ@9yH|+~%R4E$az+t=YinzLkrT_6BjTBlNWq2}CLmAMLxR0|g#vw7AR^^&sKt`C8x02<^m za++I7Btruj54SQLC<3u9MeT*>2}-uKE&azgbdnXmT+@`%SQ1>}J#dNqa#)a3YcJno6_7n1y;^I0N+3y5- zQyD#J7mQk&Qk!q-Hy>^J{BzKMEt-t+i^KfW+qhco{6v%9cO&MS8V$%2B=acY(u8b< zv6CVzMwLx9rl=oV^uUZ{r=E6nZ?CNhPKifKF(^P@KpL(382~YI#Xq{3q1DYSfwN%O z1b|T&475!cvXCArJzYXQc-CV!Tt0d4d63(ylADvJVExRh32SLQG_m8@P8Xv^yXiy2 z`4?SuXg@0>~I&vTT^bvUW~HfKU$tQ0-Y$J_e$z>x+AYSWfe1=R4?4OzQdz& z#_}c8)3iMw&$aSLZl>moRu*cATx5h6R8bwcas;AJ2yl%p6bqW)-Nj_=|IzuMTc-cB z_h&d>N{ci#RD2^CH@B=TYE7k{mf`qsgD!JYg-{&B$}25#YUzvq62XOs}X44&gL&q{|8)0e zc+$IYvBJ&*Zla2RPgrE^ki}BO_`ZJ9^JUyeUu6>Cqm4dmnugz+jDTSez z0)m)n433L#X*GR+1b^3^X*x4OqS{jy9Gk4qpNlok+6O05(1WmE9~WOffO+om(RcdL zV_5t!E&H>~kMM3}7*yKpC^x|$BnSULn?(PwZ(UI+DJln7O|hd2_BM29hNEHe zj8ehqbnh(<1P~xFrwHL?Tn?Lg`g{s**1^ZFd_p5PT9P@bV&;O;8FtdBHKk01gC$$| zo4uQ33L1{QUK9$ix;RXvXU^S=K)+xH`t{r*XBE1CYE9mCzN-a~M5K;dJdavN4S}KgCQDVK_0&O6s1gkt4oSix$%mMN_$O za-W6&92#E^TDH+VtOqBgjxSSY9ae3`rMK--%LJ zUdmeSYCUeBMTAy_h(h}7Pz?=M*a^?mms0#0*85z_8lryJCAGPBrddDCr1@AHwyHI* zEm7o8lLg2uh&$E`p{ZnGI11F2l)qZPqWkp!jJpe2ovmvc{8riCgWtdxUat|XbR9wkW%@R!B8?RlXI6dzKn z-*v5HrM0;}g&J!S_kPN37dMfiJ9KT~^96)q6~z=ZBIngp=Mkg1?sJF|>By}axO4zh zUi9JfFQfaX2czR%_&06mip{%CP59*YlpfgZ_h^@yA4-f0fe(LY>rR0c8dup@Y!nQb z$U)crB`BDrWu;BKe3ItP!m8ov*32Qm=daScP!RS4sO;E7bh7pne~3bJ{;sjoN6Y_~ zy@@l;qCpU~+O$8{>YYQ*KSono{zIL8|8<~Vaz_!PqBnkpNOewi0*xa8F0c#kTn-Pv zqN1V;l+NNya>7d0xtp;kX%F8>P0s2TFN(jwmNaVvZSL?qBE(wz_bVE=o$0~WfjTXH zTM}VbTO%alOaZd)2_wcPPElx*RCluPJFQu!H>e=>%`2B;_Cl(H;l3acZhUT!JDcQd z?N_9;{~iyPcM{w29I|GNp-KT9y^|v?PdP_=?#vr7M%jJ9q#S zdvyKgxUGILgZd9x*3GA_&Zg$HA#47F(#|?ge+x}DzVqZ8QFk{?j|GqUK>UD#ODz%4 zw2*+V^M3NPE8rUDQ6tvDAtON{%~}ko8OocL1aXYhK!FogLsFEJ_XTa?4}V5Pnlxm3 zX%?bUZMKcph~0-B9zE-X2rZ@W1r996bB4o?v7I&`iDCcJ0q=FrC>eQoh+ZnUvcGDQ z1B~l_LJZVhBS&1@QN`8+KfVpd@Ce|o%Z{V;xvDN2m^_&#??G7{l-EaQEPI?9!d1OMeE=Dl8@7`~oSQBBI14Q^%%B1^^#g z=s;3FILvPg^XLq_4rS(7-|c9}92ru=J>(uI_LV(PDPYS(?AW-Lw_J}CF+T)#Q5qm8|7*gbE|xK#Rl-=yFs zXSyUV7B0J7Qc<+II^smV9p3BDCM7SW*g!GSI_?-!$xo~Z`bvhQXHA&C@TUr)v*rLW zkn_r^;Z&&klMuqD>`rcSy#nD9YmV%UA`_KG55e>;s>yBC=7{|mxUX|Q3g=7n=|B5z zDB4@S_fLM1B3MS(@G-u080Ck;^>+%=>W5xcicHgKXP%4h=DGDXtTtG=YVYT7*g)ZH z{rNo01D}KMsa5MUl?A-3GO$qufYtOqE&4UQ>~&G3cBe*|VjGD1c{3{m&NFJuEOIa&Hhmr5qqojl%ZqGB;d( ztN0u1p-?@e)V65%!J(eZXoR*qV*Q3yYd4Dcb<_Y^mFeV6dH;}lYs$V?b71fsBOj}1 z44q$k5tl~^-wrwMnw4`@Qwx^4mJ&_-5gVcZ4D4yOGz|7%a>0sM_iv~s^j+Cz9~PX_ zYliv+zMSFrc;ka6$B(pqO)>T8HjDc^qy% zLaNuT1PqzvQC0m%SUv^Yt0+EiPI(rb1X$3fj`6(3*(#|GOp(GGoz$YSZ9*~uvGmWt zOwT0Mq+@eg8L;MJ|Mw6t$9a@ZTWhH32#?&B`NM8-v+n{_$r1`0AJalQ9Yk+xKY!1@ zk@rGnJ@#i>@*HB%8l2z{!FHqj<1od%bE5?>b8Fm*w2iDV<#*@5c%S@OGZyWzU@9xT z?)FI($s8Xw#s#ziJKxcPmiac0X6=lD?Z+IQUz4ZxZmRBG@}qY7y#@zszU@;9RwX&i zRDLiH&W789S#jsW!~3HAc{siwIi}jAyvp^D`rY~Yv7z%jQx2`*5Q5BVtU(->G*Rex z^YT9^37_{d!49oq@jZ9VAYKRkXn7wqSJKrO`xyN!;RjLy_OIv(GnUt2aL$CqzGS0M zw3vYP2cG)mq(&6#iZ*@R`oYMVMZ;niVCTAWotU;dBKc~fAtu#f$$*rpn)L;La&Wyi z=<}o2tm1wf7EZkFB*eFR%C)PMNz^ZvC<=vo9iAHI$=jXa6$x zo{7AzjX#f_l@AcskDP@fP)G~9K}lUV`PjA6Bqty3UjzdSm}$!<(J?^oqCq=#W(Nwq z$!2O%L#U^8_Nr3}NRZ`jLZbVIKn6q71uMiC*17)UZaR(y=+g|26*e!E#H` z84gHuW7tPb)xyn|sUYrLC1*5={*rhcE*m_s6m+DhQYBqMLZB zb0+u19!rQua7{JSo;1SM z?4)caoQkbyEk!392kel{2DQ=aGg`&GXWEpu(~xxuMk#o5GGuK1ns?m>`O{butfz8E zGyl0VHqRzvtf#rRVsM+?bd9LqRFd|(m*yhw;X=Gtguu28@q0;0hvr4 z|5>ZXjznJ>9^HTYPm25hKL~~5Hev-eNh9SDx}Aln;?TQ(wJmuYc7iGz8Vsvaa}G95 zT>nQ|v>;RGzC_f@AHJJ?^Rik&nQ&CO!@j7wO4tXbBOg+ z5XSo#WCEA{7@Iv)oJG7wk3C9uts^#+*b~GkR4gZ#3DNvd_+}k5dsxwFWq{i1`fHT; zfHKyTb*C?(G}u;vyiCtIa#!myoNzVpruB8+dPnR~?VN|S0G4Jq{RSv${Amdwx)dN# z$%g{?;JE&m&DZ}?7_RHcNn=SXFTH!F>_-qAR%s`A7zgYK)Wl zzri+vxvN-UW^s!mBFCmsPe+YOoEcgofKXRW#~OCBmk?dvurV-bv+PU5*yDHakGXB) zm${rG$_{#xE1Mo`{O2Q{MTbzh=O zgCaw{=9an7-;L$J2WYo2*YQ{VSWQY{#$bPYRMFbP6!b(()VN`5a?L6IVUSG69ty`O3sjJU{)r3Q@*RC^YFnl zTzx{j7K^kRYN(i&&pKY{MUg7<&VJ>7Ch=;WE|q?Q?$$B?;UeZC@o}ekJgRNOX7uo? zZAa>4^21$;?)_gJ^80uBDW^Y#{r}>8=KG8D2Q5DSbnBVR(}xAD_jyZ>Z{8z^(Tzg; zYrU%52ksk{yBfv6IP6-!zTJOuAgr5rD}+C>df^E1%j&4J_UjHBtWdg+()ZAsZs(f; z@Y*KjL96%yo2?#VeP}jV-9_L$+(@9yD`}vxcctl2bIWl zPg9e>A0}GIRslKbSaM3+4UU0uaTuNQp3ykVgG&xp@+`70rtM5ElkRPoPi-6Jt94S* zDsJzU^?Gjhd|=5t^IgWMS|rB(%qV?{`o1R#V3rLuLN#@bAa>#WP(Y~CC2mrfUsmW3Z&YCC< zfJa`oKKtlNcg00ZU$S9Cb{&RGkjBI|*1Rsy(k;elfg*$1wRU_Te{=~c>cKe$=E3(`zhhZ0=>My^Gr`+4g4rf2v6)ix zkinAAVm%U^`|WCmR&kuFm0x2HUio`>e7FYzt|ZFJTSFKw{yh8NC6othIAJEd{{L$4 ztb^L>-hEHsQc5XOq_kKm5-46EP>L2OXbD!Vl%fHGOYy!yk>WvuQy{@zLh$15?iSqL z>CNxlnRDhh=ggTickawRcjoNB_TH1M?CiB?uk}3Z`F_4IJ<1rVkrHM^E*hBVtYt9{(un6FGxiCc9Ageoow6I34!*aa`ujGmRFxmMpY1QsFZo7vld?&Mc1MHPsmrjU{UC_n8r*`A48Q4l}qta z5b+AmEBxWCR%7(Ao-`p4CJhxdV8C8Kmhdrz9rT;NoP7VcrvDH2NN5J%(ZP+}O=md; zR|n7~_TdiR61bp+|cfF3o{H`r@$jBQ~#)}__!gtp2CH&*1$^Nm8~#8 z5OWw>E#e@z+Jv%w^RP$g<@$R=%Iu`7Omhl+=(YMP%IG(gwxF^p1=q}neUGe-``X5W zx-@??Z0)j9Jy3gm8i_z|9w9``Z*nCw9yPN$BuPy0u<<-T?Ia1va)#@6Ob$wMuXUwz zP*?WwkYe#rvYYSOR4S)mPavZiNPGQ((Vq4q8!H{-Kd)7j-BF3RJ$hA5-S#^Qdt|{W z8oYEB(mz3Nq=}@BF#bIW--qfFkC|L0>S*WJ@>)x@IcyY|_;=qmTG_eZI1HwF|H0{B z;-vhq_AnhAx`+*4OwE6N&nZ^_SbU?@fHQsfn8DuF&9KttJ+lK~O6XAydJ@edr4&-j+W&e6WM7PsV?6 zR{A?{UaKAAd>}bt{{A(yxYOJf~mGC(*)WBD* z&5E1Tf)*8T6IGfd1-yQ`il5v~yX1DaTZzoF%!c%2G*|)S&Wmd*Itv2&9 zY?}_<+4@Y30-W73e2g3WZ=b(ELvPO)=wo~$pAH;)W#K4WRhF)CtAr}HuXQ_$#>c8L zG6&m_t$eiF6S%f$iIp9kA%k8@IlnUGo=TaRCN=m{!g-wH&2At6gGx5Ok)+Dt#pCT` z`OcF=DV94O0qzTLF;$z=FzuSzn05g?9~z=yqVjSN#4&Y%f_EhYh)0r$in=a1ZpJrn zqdz^B2bC&Y5uUNt5i?ZY$|Nnaz`-%49X)MRIm0%uH(hwG zAO5or|Ay5q)94QrCPONd=JXn1|6oDg;@2mTye846okEe$!ig{n5nGP4MTtq2M($M6 zPInJy{(}Zz#p=J*(7a{@LC-quXT0C=^Fvy=Rj>^CL3%g9Wu)_e@Rk2wr?j1p9Jdg= z);`IOjkYxDFpCxTI6ifT&oElIH@}wI!5Y)uh`3J2;6W_USvQQSVe%%MDEi`jTn@@u6W(5xSF%)W9I1CrJp8~nnfY@IUPdG&ERbP; zgI50)!glrthjlWhy_xMZ$kZ!cqK~(A1!wz3n9Ytw%{dR=U4kAw7u_C9$eS>T$j%(# z2waY4tM*pd-ZdWNfEd8VMZ+ICPwg<}--)=+Ls4Q5v+=bH7E=wVj|=aOBOpG~;>fGO ztJZulB90U0!`G-wz{je>E){}}u$eN-8ba($iYIocCW?MfEPU~JOwT0SSO4{1Y=Fvs zUhF63Y@v>hq5&@OTI+SVO2tbwqYhx{@uyqh{QHyne%oxHi!;@qGJSGs6DwjFdO&1n zBGfhbB%>GD*sCgBSzOa9oofW4N@w#xHp6usAk zR%*mGy%M_L{wI&mgX0qJg0h+9r*#nJ){?>?7Mna0bl}h|?hfTFt?*`g@?21DbZsk=mtckby-${UZ=8mUKIxWKK^;ptms;Sl^+E{q3@A_R zmMV9&>);?nC!Z^A_H?_#dPc)0cH6vAz9bq?^Er7F2c45@rm5Rv=-tQG2)Y=ftTXj) zryjLPx}3w^(|&8{_3MH9?RgDX;n7^jy*;ymQb`q%hN+TB0=xCbo!hJfTbVi%qVCX5 z{cQ`-2ge$ZYw*_xsXK96A$e{2QZgpe(B)CyfwsL4v4S)3rCTKKXSLsb zAO8JKdGI9lY}@tjTETYTXHBoiaqsJ_xJgt$x@)Vv&VpR!XX@Oh1B1?_&ZLw;%`2ce zvVOXTfgSMTaW@An_P+aFI+A8~*vnJfx#e6S1u{Gx4xxIL%p+a%m5IIbr+WXz=hG}D zjXq>M%rk^Gk@Q%$$IS)hEk6G_cwJa47kIQ=0Np_}_S}EQu6PJCvtiva6kMu=WZ3u1RU}zW8a7$hNmUl+3(4jci0V+u7)Fyr6Olv=r9_M&U&89J?ru2a zJ9g-sJ!Uv0n$(_B`1>gb=dao3NMl28(J4oLrLgXw+po;p@|cT1t`Lr79)IpIS;WD4 zMMd@rn@cqlo%^=OS~HoZrw^?>>2#dPVn)lF|8MQ7gbX69r zSksvlg8+C`S?_Q&X$Yi40W?pU_|#o5Z#qJ4eAr&42w3Y`Y*c>)!zlvtZJC7G=n43E zc1}n!*h-Vdu4wLTNH#;kx7c0Q_c{LVk)Mik06L#qL&Z4?OygDZ9I}BY1ikvE^wlMP{Ky&P6fo+8THCVk_YlWkx|{a6Xq3dc(MvOrI)6~Fj` zL<(-vdC26f@oei@Ps5v0E?JLO?uPVNsLa2%#wn0WQF4fP+uTkmj$mry0&s!_~F5*U%^oez2w+Hw;Ld-3Y4&aRozIJY#F zOz!!iUeZWt1QGE`GyIhhmY&n~9S3XDQCr-ZpEEbBXhrAkth3e=bv!gfW6sav-Ma6f zpvF9F#mBwEzv@T2J3bf^ncWd*KKW8MA|K>95Mh(QwIxfCs^s*g+m$&Rw(4qFz%1S~ zqR5Ep!dmUg>z!d1_hrME_ib+Mv?0lbwV{%N=b}BM@1wVuz}H86D+AHEci&(XV;%im zZ0nJJ{J=!gmxXZR(NHn9`OknH)V^|6dGDFxBNmxu(Nj1n-z<|5DjflyKn*1lH+4FO zmWFP1;&9=1G_tCuQE{bu;ZtSHGUwE*2=03Eb+3#Tk+{kYd+JB@!fo7f1RA3nh!jZkLA= zzaETVQHE`yw5;047r1o!Iz+;S$9~9!DFVv9GUTtxF59Q6V5uo)sfMKzJ!z@IZC3jCMuC(#7QoMvmOnjWiNqseCq zjlWtdIE%vU=b56ZYHLM2sP@R&0BG}RYR};QCriC)Pv`XB=bRFKMRU`H14po?TpW=a z9+sfjFqgPzOGh&@1`eAJ_YV3^EgRuXA3gjRtvNf$`8>G!V%fXfE^6h)Mc+}#Vduxb z!`wb6jQ5^gzj5#~uIpS_JFk5=0Y&X#hB>RM?IoSB34K+y!^*a6c9UK+h7nDM9N}zT z9T5}pX=|jpcfqaLauako`Ea>jm4y1y9A}iDgN5U1VBJT=1zF@s#DPC38k>6jzM^MV zuTr+dd0qf)(kk|s zqc#XNHWCG*+jGn|r%e(B8I9Gz1JT|_hlX)@ktm(jM4ain-o zb`g6P_~>@=s6}c`(m<>vSl3h#9UZ@Ma2_95A)kP;U?K)~y;z@aMTLGj-@i$z5;_bn z-0KWRZd^BA9_k(TX)jOe8>1;SdPI05}QFE{Uj`p z%$}RK9jnx#~ zT$Rdvzd>N0z5%*1EBe_i)={A{r^;}fZ(08G=~GSs3*nOron}T_5MM0T0Xd>)UPmGb zU;KmpZnefCA?cKj*_D>z!%<8X8I9AN13@(4L{ZnQu_J{^KM(;^=nEZ0uYbgnQo`oY zYMIIwp*7frHy~%@F}D6tB0-yG(jr&2Zu_04s2$FT7`VIaRuI_7m=7R&?x$bEfj&_u zkPhgsyowp9pPUt675DYyB;P_$d>a^W_HIxtOSevZ&gP?&YQ9k*aG{*cr_p;lv#-;1 zNWSlbW1-fB{u^tWRfeiAWt(wn@Y1ePG2`XiRgFu28GbH0tj_W<{LS3*lVt)OiOKgH zb80y9F)Qk$T{AyQ^LP=hGA+oOQZ$r9%|Q@j$vdG+>P;k`NJT&Too;qW&4t`ea&%Uf zk27uY809(u458wPs99v==3LFdyLQ)QzA}2$GCufs?eZfYe|h~PCu+LiTQ*nhTCQm( znko}~)hZy#CW}uqA^74?UQw7Wjo--ZeDH83Umll)n;A2pqrQ}4uv%h%mMP=dq&XLJ z8FP@7LSJ9-?X!R&U7xy{j(wD#P4ZWV5^fht*~%2H7YVTm6$C5#dyTsqUzv`VYvp3T zguxJqLBFMjz^&SQyZq0c=lalSaBYC0#WAP-r(*IQ$YsP#GbGTo6iQVv4h4rO?1OS( z+zK60jZ{9#0yCk?3x99~hI!Vf&AneW0oscA6#9E&2s%Uq#pE*;7|q%^XI#dpW8=Kz zs7T#0k7aDqmk|@GV|`-3?SjPkYq~HjGmB5VcM4ahMekss&YRd?4JPJk4YAPWRc?br z^LNYH>4bwyg{tJ1Js7$c)q(_{bm?BS6qK0?MVO3(C8UMKju*ijZryV|C zE^93r00}3Eoumi$l9knvf32^;V$>vC#1AS~BUxC+Yr>TaM#RVlPxOPL+Jrvswb`k+ zG9!5u^Hbuzl43R*#6@p{{crh%&8%}huaJUEf#FY{YD&i60qEwkK!xv{ZYo=P>e?7Jca1 z6@n@3y^8>Xb~c(KD@{Q+g9qdeo6?1laYH}LTe4WNmjCj%@xQv_tpvw6m7FSjWA@I^ z4<9<4Ji7KfU?Q1zU2sZOtuGG?gVp*O>?(f8P`F*5cCnQXq6xqQ*gCh?V4y=TjX#{K{K%NRlD*z=0n5*Px*}ndjp;=1135hSH>&&;FNX9 zZ(ax73wiYP_~)BrPa-+cwNIX4teK!7fyg!`L~#`uvUmLxsiqi4vz##?c!sjcHKE9$(1wUJUGZ1VA$pt-* z@!$;U;$s!!sp{#v+*sVY_~Absy(MqlD0Dgvp>^n_%T@GlOl!ruIX)bQLv(RQG473PX)>^z}D#VKYNBQ`d!fEQzz|& z{4C+xa;fW?DvMC@uQd{Q76sck%6Brp&+}G`@Il_&lL9Y=kEWbhH>R0-6mz6{b?CD6 zV77TAIzrYLgKM3(JS!r^s9#zkcgTMfh^3ZfB4n;TCQ%?9WRO~Zmi${JORjVFY4 z9P+Yb@a$Av^D>0^UcSNVbyGzV`5LA$NyI8c%TiZxJ=|?95bx~8 z_Nw2mj=FP}qDFAv80UuiNuLq9qhC@7r#?ZhJ7;hhFO(tAwo_`AXH12 zk2=@V5n1vQW}tm>=eQ%wg9k-|Lk3LNLeS_B4F#pYJWRD|SyHr~o{g93LuI;=?XRc4 z^j67PUOP00)b@W{m4?4=mRUaN$oIN}1$Y@Uj|$MS zuChcOFQRA}DX;`cY3McWhPTKJ*(Q7BSrhk@9_tEf@)}D* zM{V_w%&V+g_uma#sZ<`rbd$Pnqe*d!jcleC)avi^?F-F4P!^KK`+CBrwzWpl1r=NC z!#v%`ei6e?gHyMJKia>wKlc%kH&oGXVQTq7l|5}%$x!*aj(jP~k2z;{Dp=d%Fxise z%-TyhB;TF+p-=dyhW1J}lImVawMuGX1+b>&%09mGPVDx=TQ-&1fRxmel+;)LJAIah zn8oJnMH_!j{nV`^k{8`;729lVJX*DJ)Zg|m#K8un?tgHEX5WrPel}=+>C=89AAiJu zrAl<>KvhtUUMN%q0WOZMd2p~p6!kU=ah}6Y`)~3PA77$+-zKOrOHlncu(EvGU#hci{3>U#c2*Rwsu|qM_^T?mv?Zm~Q*tCh;BI zwa@w^|N9+ldwSQsArh=HNszhrLRpwdio*yr}*_#?1nu0m+ zHk?;)wDa$eL+-yjdT@$RcZ3*4vn@>~jn5(F!&((LA|!NCLzzbh8=FS<#LCzrPrwnS zb%lrYXn{JD_P~`_+CG-hUGjZ|ZqL1If=4ajD&q;CgY30TbduML$xTsu;xxba8Fi2H zf=nI{RrORDgaTbIlJD~zsH7L;YR(Lpz71u?XDNEteG|t=GvpeDRI78Iy>rlanEA}X zt0CbQ7gbJ!A3ZoGr6nCZ0caEM&SOE6J4SXhU?b%2JKRRW!N12L0|5{@0WH#Jv!OW& zAVmnSY~R@Gb4yqS_tTy<#lEh)FkGH{ekhfQ4Y}?6aZ%}aUyLv@#fwTld(WXTGAl5C zWW-v|S>NrA_gXwXK1(I*N5SwyO{aBIiIDZFJ7w*3dst?~xXdIa0$?s@sa6|nfM`5| zFOrLg`draG8{nspy^35|C9gxDff`Rt#`Iq9?=Ri>mK+U6`Z^F5`!%SPaRqFp6-^rI z3a&NkMxm0tT*%3(O|P8TF~36re)c0XA&J~51M1$mH=)~JxJ}nnc9DSUh7MPOYF!fj zZF6@gS^Xpta1d+NLxNsn;FH1RGfYlKZ7X;8fw$`GM6ik+;ygUpxqAxrt_)gZd<~@w zM4FXZjnXqVxwom6SLo*DjGn{~&{VKcQ$~OtnS~(|ws={!th-qvV}wm?i}xFh3eovg z?SUZmh?a&h%ksz{A{N57>^ok!eJ=8{rd^GpAzrbSh`6bkB~Z@W?w!)YuZwT090(`D zt(rBeLd9(i2(~x;VDvI}h#sWHDV$vypL@SUc~BYiAUn-gs)Kjd;PU_n@u7c09a?5F ze^w3z?n>cfa$qTEI`mo5X(BA1_x&cGI(7RL|J9H74t-A#wff03sZ4|1?8?st0d_Rz z@cSOm5Ycap&$sM?UUyf!jH{_{7NmK!#|z|3W0fHu7%i3AF`=z@=aI4Y7MLzHx7Sl= zRAP1&wN7;dpnaRMniO?d-2E zs{G0Pz7ZDO#39}ED{wLJ+Yj6%AX-u1X#t3+)98kYdN2HDw*PNGss6_W3M)C(>eJ{| z>PU<}2}52^dvnCuRw7^{l|>VCe86=7J2ex!865V7ErAr(>A}ym)eq|Qw%=~ne9q9O za2v~J6kiroqGk-KB8R9SZ-&5eajW1-(k>X4gdF^jLHa%uqgHK5MtjY~tSiseZ^QGxvLd++x}ZKR3R~q#9lDK2I**uBlz8$IH`v zPP2f@Pb~awC#a%GMnA6*O+i8ytnehO%awh(Bt{Q59nBmf7enUNjH0h9#9NT<7Y~Q@ zmSe;n3D>_&Rct|88CQn^;{~IPSmF?C1YvNmKSTU;&2b8KQX`O#q1uC{9s2?XxxY!9 zaE-o-YR^E#@k4ii)swy`!?~E1(wEYgBn3Rdx-&GBlG+<#7>Yj^#Ld)Ll&7kOW^yDg zc$v0F#eKf&jQto*(6`|#6rUz8LPU#5=GXDV1a`_7@iD(2o(~rXLZWFwF~e)+bHV$s zySimo)d0{UmKZ*4-WMy_xTX4z%#w5oa!#UZm?LNH;!|G#vEfoq_;iaTQ`yyqk%_|| z`P=0pLoi6!sb(zcqp6i*SZ{M^G>}I$s?=KIi7Ndm)1Bc|lL|28(fqditq^(E z;`v@K3WTy5wSP~-`jE<%Sq7#jz3#YC!K^>eb(;IdU*$beMT6hg0GK}8=xV#8bmTY}o)HgSjDw;$VOQT^7fC@meF%xgwR>e4 z5A}F5O^TCF*)_eQHh}(l4V?b?r{;*yr!|^zG-MA8xfGONoszjIhHEuAF3M$$s;d=7 z|9~hR0^UV{TjJhkthqD^_59MGl1dsD>dWp|g1(XJ3r^8#V&;sTFwEW6m)>?HavZGe zRkkWP+T5;9Xsg#C2`BBC;gf=mL7sP;nf0cD75I)Y_4Y0)w)e9P+x}xdg5;Us1gt+5 zb0qX3I%9sCaDFHhbqGy@Pvrc;VUqsJ5y~QRqtzVEnkJ<78gO(NzE@vHm&B^O$1wo7 z8&HRXspGMMRI_rmivJk}NFAx*94kB9nU2ST|90oAzD)TbLa(jRsJA+2y-1Rc*WVbTfj`MnAt|oy zPrzpcgRQ`@a&TzGpoH`>U^aR|vSYKT#KyFOlFrbp8*6JA3@KMu80g{?DC`!^P2vO5 z*IY94>S1I4Dhbi;KRDxARHH?2BK8InzrZhl+D$1D*U>%cDAqD`+#-MS=GJ%Vf^Z^J zT&XAWJDKREzAI8f))f@GJzwaSIb}7$MTmL3W9ko8Hq9#bTU2V#b?=YNTjqQ%3nk;Ngw)tev%lGZm7rxvhzP$oQ%P-Tn zf~sD%;8v_~ZaS*%M*`rH*ED)R+!ptZb|&wtSO4HFiiSbte6beR7q#S`c1J$1ER;LiGwXW^v&-kv=B)U*c`|xk7eQp zuLHV^;OTa25Ns)zf$L!IN9NS8Bwn*?jx}c0MU01Eta^5g(Xm%nM*i%nvWlTeq&$~3 zC8*OKKc47wxlN*%R;J~~Fx6n$MXX(=F^Eo`6t~=mWx#vEH{aUV1%l;VHBBdH+M!SK z1JvXFT;0jE&NBUa{Q3xoktRlZj{W z&3g?K%!ZBwwGt8e+%@;%s%$F#9dZt;m^38`Qtk`pr_x-24M_R(!3aizVu`Wj@DM58UgZW!U zUf=ey(Smsd$LE8ch6y>$BhD*)!|mbv8SiRa}leq*8vhg9ejZFx8Pqp`vb% z@*C9$$6kcDr0C0x!iw^(xO`wj?9(TODis~6aaY`&@O9Qv!z8Vl+`M;W(+qO9?yHfG zH!Ttp1gbyYCm$#oxUo4=>{{_)tYp4LD%dpy^tv&10T?5hu7hBZC~4#e~KWd8-w;& zpZQXZr|?C-=BC6vefmb8Th^u(!d1b?=dC@W_SRA{%3RI>QNx2H5~|>0fj_bPmp|C~ z__~R??N|ZHoxe~vE=HiwB6~jVU}-aEkm!kM`m);`NuxL0_`WKYvb9o&JOQ&!EBN5U z_nRZ%?ivz##yb%-d(-j2gM) z2Ztc3&JSh2N*XfO(r*62QLPo^>s{Y+GBoy@j@5PkbWrIKK0z_JKG8N+gtqpEm>Iz8 zUEML4G)Li-qJrwLI24Qndjo*e_5tuY+0$3{^Z_z3tlj57wt7<17q~J;QjCYvr#QF@ z*PAXj-Rui;-Ga7xHH|2XkT)nwA+JC)v2Y4^GhI6)aqoH}eu&EUjFrC(^ z9Lm4tfa-f}$k*TeX(hJCJi`=zQo>}YH$i4G>U9OQYYSq?fys@`nN?F*OT}7&+_~|( z6U}mWD_h?v*_6&uK`Jkmpn@=IOOWsLHz4MwzaC)z<$-$jTt1=#>*-!`Gd4ouI|AWd z$9dJ%BZD<*wETl}ujGqW#2z!m>))eBp?`4tq_BBYnPB&SelCvbIs?rc zQ^tG>pR`yb=&S?r&{(?Q-zY){gjnIX+kY!?;V{2Qe5BA;@Lib3TlHTpmQl!z%)J|q dZ^=?~cQW_Y)eq4h{`q?T=jea#z~evD{|i!TdZqvX literal 0 HcmV?d00001 diff --git a/images/monday.jpg b/images/monday.jpg new file mode 100644 index 0000000000000000000000000000000000000000..01794dcb2aee9b1be8a9664e79b20571cb6a802f GIT binary patch literal 48983 zcmb@sbyOSCw?7&jLUH$?#jU}uxVr~wDH`0hI23nlf#PmOf?Ej`Efkj^#fw`hQ1sDz z-(BDP$M3DT-dgX>n#?DA@7ZS`nVrv^oTtU7bpXDKg0cbt2?+o|dOiS8o5-xnva)Zq zbTk!|)#U$6flLl~E~vZ!fRnSQyN;p^$jBH1`X~FJVqxjw`ak&paDDcBHUG~$05HS# zf6)KGiLk7!JuIJ%&Yn+C_vgx=gSC3b#J2y9egENi|Bd7SVRtN zg?&6-J)doc|Jhr*y8pu~&sf~a(d%Db|H{9JUsyZq>powNo=<9k2S5j)2#|Tk|G$?1 z^ruS^03iGa06HNJt6D8R?jr80Z-2SvZ7vSlGa9^bEYx zd|(kVNl8g&ZaHOHaU~%MNwI%Ukbv0OxH!1f`1sUftPHGT|DVg#R{#MfvLlK!3K9bV znE(le0O@Jqxd8xVWR!ms`+v&wH&g%`20G@mRudn9go=cMjDm{t0s{>L^`9X!3IRPT zAs?DFI*~R8gGB=|zZ;`~pp1^YM`%iFVIK+EEAYFt4JH=d#wk)}Svft+Fd_KaxqNt$ z{!K{(d^2;q)@zH#tl^n%DdI+Wbneij5FM%u z4ssvt4mg0)31oV^Xduvsn#9Doi4~D49HJ-OHsU-Q-3wz~lh{s!CIBLF!FHU~Y}IOn z8wX8)7+sOD1R`Sc?Q|6w7K)QQZhRAb{sKsSjrp5bhLuWI?)PZP{$UD~lzI}gmR!s%Wy zYJ;}Rp9E$P_>X6@4PvtzWYOf(%GYZ)?r#ba7+yEWdVau_vIn*K( zlkp7EG&N^eM#O7nz1!SL)V(sh1bLA6R`W=b*EibO5i!P>?q`bF=8T80z@6%4Jhql? zFFoUL{+N`d=-Xpg?bT9D`6fRBcJNNoIH!Fh+r5(~p8)vUM>DsaE|iXtgfSEJ7Nb6^ zW5mX=3iknIbx0QPBT^mIup!ty$EiO~W;eVKVrchX(d+WnT^}_>iU7rVn(}@<$?nTu zZQc3`t*($@W(J~Gm2SH`gQqBK@BZ~slai_?hO1&@a9OKZR8^nh=u{x~K>qD>&)QCQ zDgT(lJ4Nl@cmgQoEZb?njV!q(i@w``4liNGH}sQt^1?6F0KB~!&)q0+B1tgU_q2>U zAMIJFDS90PTz^l%aW)iIm2K-6uIk2tU1M2cL!=6QlY^f-2GwcC>n8L)LWX=L zt<8c+{=I%1B7l`=IZ{`(nqNU#gu3d7tsq0^$q|F~|XJI&zq05~*3j-6SrJ%#N zi;{CNDjIUe^k4J-H-x0PT3yGip$+z*W>?wjwK8Hq4Oqz5yw}#}q$knhV&)a;XJuWD zB$dhhR2~RQR0%uY0p07tDz{u*K8S_^e)nmayRft77Us*D@8%N!C}}UIF2#!_(_-__ zz2H*(65z*dWCI&^H>gTt-SV?+rDir1T%Hh;e_tC=@2hE3TW9Bl2%Fq``x5s-k5g!f z-TQb0;-KgnGZ~fSB|H|p%op|{cg8I3GxqYg0?@TaA}AF-pr+55QQ4VL+Eqlieap~Hp z&9@L4ravsgtFcIReGI*59QuS&GRy;n(Wpv#_5jRRk*^i068j=@+ zXN6Or2a4>vv~|DTQ4w$nn_yCFGV3^ClBo=S9@xM?KG+a~v&_Bm%=aahMXnXjA-S#o zIxU>o=Oty+&@ZKMy^*lA4WsRIH?Izs^YSN|f5?2;eh40Bopaetf?jgHe9Y!q0S1Ma zZ0$`2t!tFVQl(tdFbDg@T_bPh%@y0ERq_rF){{Q6mLgoqUS18&d{3|`ah;!b&xi-H zLH7%E^}CzYPr6FIF561PwK>We6QjzwcY99m8TI7G`-kch`LZaD$;Bc|j*HjW^fC1V z!;5US0MrRRM7!{q@wuLVuYP!m-OsAoK!doFp1AqT2M?uHjog3|z-eApDL`(>r=^SL(T2NpH_S9YcQeF1T!B!q{`A3EA2DKB#rgdV|4t&%N6<&H7zr@5SVF4RkF7+ zPQg4{tH3wZc8mKw*Pf@6?If)C_BoQ5_oS?j(o;~8TkxAU65OVl%ZtlEJ*O~`8Bqv* z_>@zuMrqFFw9s4$NYa~thD>{QU$v*Y(m%;o!lLHTd4EQp>Jwv3E12lV?H7p!N`~{G z7H*7>lF~fo4m8u^(l3swp8&I5BIi+l<&&23r3{Y=1W$lmd~EgJVyTCy{<{J> zpwwa-`f)IXnx-YHXh!;-VEA68myej)X02N3a?Wl?ul+h3Q-_7?soXW}l} zQA!73HJyzakoALm3`|42gVL#4a|3=TtUXAoo9$jd#Wu>R5=OwuFvA0n<7CC*`V}$W zQqVG}-G8NUMa_DL-@!h-1K_4E?K|?JN2=CY!a%L*6E72+jASxnBQAHcRu5l}JmL~0 z+MGSQHjOMHi4f|LFcgpLT--ZyPp4}fuR3nYb&YOIEswoASn&^07f~w`s`?foIO}{J z#7Y{o*#HG0yT^`@8=K83^#XIO#8l-x(BmN{^UJzsvw{4({e6 z!q(lBEj%=Y)f{|K3o7fiGUJ*hjk`EGB{M6sS#$lc3g>Z3;1aX_@aX*wXE$iUCjuq} zC06_pCh>ynhkXlgPH)*>R-3nh7B<@eL~qx`rZ+lxD~?TXC17YQeL@@Y7?sL`E&K*qmrm1pDNW19ZmRv)%F8Bc%7k{PAb%@1CV^vgMN1lXu+rG_g+r`d+5~ zRsEV$#>g0q-19H$k11!nRe~*RMz76z7VlqbcqMmce&4u9K6o|268T*Om4|43BJ=U{ z+@H=+WaZqAE3MZ7h36{sNmn9L0>bWTHIMr*chfO`1lpb* z9#h7;FI~FpAK(53@#tX556UyjzPz$V8hA6??5}ATpD2>c=<(UX#*2TV$iYXx24_+dt$?>$_0SChVe*7_rbH^f-KzKM_m zkd>PDE-4z}Sxk1KlF_YqM18Y1BRb#cu9DvMwegwDZ2cP{Ioe@2>FeXRtLXet?X#th zH*+sm5wrSr>sa%wC*H0D%;ygVEmuFNPTc%71qad1h8)*<@ZA{}(N<>$7#~O<$0<_7 zd_gaPYl@NkoSsd}uU6aI7aZI|ze0NQsWzuuKG<>8dVFo{cy;sJr*7|pHzZk%71vF7 zM`WceRG>e7@Vosfmboz!68a+$n)?@WG;MGP#x)>yCTAy~o=)V}PX3)t5b}ChVZ^@1 zf?r*ZKHX1QE)O4@nRNXqtN_1`VUUzC^gD@>1_h=n=J;FYsc(idCdmzF4?cdLdYp3^X3#ukw!w+3Kv`>ggj?-xh**ga|Ni?+T3B;9FrI+ zkOSp1lVzCbVTtusuco`oe;S!;3kU3dC4mD0)-=a5>S7U^&KnV^SX{#h7dG|i6E);0 zOg+nlLLHL`p=_m&nUvVRL zn4M6>hO-mMuB6Y&tEeKgzIgy*vygd=nJclb`$1Ta12R!f^aLng!7DV=UGl9^J5!LuRkczf=o7BY1H}AADD%(Ob7K(Zk`fJ_5X?R+N>3zB{t7u;C0r-QY zd*!8{kG&Tl54b|`zJSB?Mvl^TDl)s>+lg$n5zQM?SIcWtg#L5h3nb`9l~6v%$4!4N zm=yL7 zL2Fcqq&1@3N9CZ=sD}Z`!TU6ciiZuUFWXs1)u-4lQ@uYnsx2TE>$@&3`Q=imY6brGqQ{cfPBZ6J|DBb&Dr)Xn*Xf2Ya_=^dxJ(FfWEGTSDwh!M#K+cgy&5z`Djn6%Tv&Cs{PssTcr zD?#7r2WJ=CUt?!OMe4*;qmm22;kFSk%T6^Mj+j~+wwQ+<{i#(9@W^Ai zg=E$%x!regW0H`MF{?Id@2j1+V;SlBnkbjERqEUT*zmB~)>DVtg5+rB_bss>{fA=; zuyq+cRG}-Su@6;BBo?MMj)ktMTagLyu$3AB;ys(a=`uWpfSL`B?sKfJrSWT%754(l z9n0L~5j;Ow;#`Qdk&y_nlUd#2$f_*wr9EN)UOcg3BH_Wd-{efQ#*BCTTb+r0jtX*Y z!JlX(#yiT+yxIO{Q#wFGw)ql#ps?uM(X>B0W5Y;9W=`tN;>$f)wZd?@Is_@dY%As~ zvvBX2#KzRWd>JXJ;TaA$n0!d%v(7a+vYS?X(okZ_>1U}(RQdNCs_ z`LZiYMw4@Nv=2dg;kj}lnPagBiaKjS-OGqcG`{B1)NeS%y4=)jz6zT5s?cGzYQeCA zbYYmantN>ZDF3p{VrARM9l>Ot8oH~O=B^!I8jmhb-7L4sm=C6%r`q0=8DE(t#yxW) z&X%Kd%_QHfitl)LM3PwkxzL|iBso5LdyNaav;0p0s=qu+D(4bJ26UUon>}>*S*@>M zn_CWurqJn6@gHoP=iZZvo+noRylN^Pm`N(pe#FF$B52D_KxI5}x+B7;9$PkCli=B> zxxiLEGM3X+lGMd!Rf`V<3*4x7bqyt0vr$uJ2CuxmZU>@%&|XjqB@zh z7?SLsR{LK(XCwWy?X|KLxIi-CP%6syGSTdJv29uZ-|+1=I^%nBsS@|o&vZ0`bl>Ti z(0&i%D?ua4tOx@YLBa-2Pk>6kN4H0Z4rp-4x4#L0xk+lD00(++a)Yx^vb7(_$euY9 z(rCHvV3=luiQa8xT|VddtFDX8#o!H2vkp`=X8{8#ze zmpzNviB$mV*T2!F-c9dO?p=Pr)X$WB)30hbEtorPn@PKT5@i#>CAymuI(S5U&S<_` zr|*tKlIZ5F-Jca6^fApzSQZAGQj2xv)un6Hjwg1{_m_tECl`-+sB;#Hal{|7jTdD) zy4MicsUowhI)qd!u*RwSf(jj?1(P=xfj)U`09V!;)GRPPFF7;5s|Xf4;cpA;k5OH@ zWczoSgc5{{XZMNPt3dX&cR2KnZKI;J)q>MRH^bXYsa$-=R`BJMc9xn-)YeFG=x3zK z?ow6d6E?bBlL-SOS6Jdj6i^nsO~tR=LP}HEqKIhUe?XLfD!p-<{|kks@2Iu|X2Hk$ zL`5KqY%d9K6Q+q1>t;~qTdHm@-u+y(5S1L%9ctZ7>TPC2OfIfI;%Kosh+C{>iYiT6 zf#6vY?B=(Vv1~dw6ERbszT0P->10hdvtue1WSU%%ke`|W)aTSlt(FAB0wgXVwxkFD zfw(r(JLy*S8$ely&XtSa&`euWsB3#_JC8cv4M6gn@w_6Y<&QYNF?U^*duxeo?lY!? z5Q2HbikQ_?ILt;f_(RK?%lqFw1Ee9m%9x9r8pKb4s++d!sJ0ALTCFe?$4d116v~=k zKyQK!khU5Wm1Sa9pbh49tfVB)h6CC_gWma|X&Mutc=#Dos)JH;5F4Y4`L%djfqO}{ zB#RQ6zlt7vx5=mv;`thmg-aJ7vVeQNHSfcQG*0QZ)|d~ZI(v`Q;HFQ2Mac)?hdSe_ zW*%D@+q*6b&HL)=$id5Lzo4lf-G*#WfDCW)@xhq`WIznDA;-De<{Mf&ioc>sx7?-l z6)STT%^3h|S0G%1^Kg8dunD#Lwn`Z*C9U(~lX?)JSsB6unya znFh74AnV$gPd@js|mh$yz#Of4LavnOX9{3((&ILOU`XXLAW&;R$4bG^{CeqwxD@#l03}Fy`f$a)!fPrJ z&+wMN1;9?@8oZL-JC&X=(f!Fw`+>Q7ny4Rvac-vm{_pz9mj|jXA34RVCcNy^Gv!_Q zLLxO6pf8@;A-F^NZz&jCWj@$5wU?n-Es*NeO6sGeYUrQ*{{CARYhORTn;3<^hw|}_|a{6h+ zG{1<>fBgYY^|6^>e~a`L6RTOZrt4Li%2o)#5U*(NCV4?V*IbDXj2_;{CopZ#>~gQu zoqZ7mn31H-ujCOlcsz3+c11&pdRm+qN;dmS^{wNCw{*ed55WIcfB2 zyKcUAn%{q>Pg1G#c56d>C?8x~Bp1IOyuBHgO8pJt(!K1Eq03DEA{GBWzh?Fij-0zN zbmKkDqoJg#yy}nYi{Vm?S22mCjDOuOtdyUf9mo?G1Q&_(+U~W-7l$;C5vax zvi8bRW(N(%lq+v#q44OVjwb+Q??DtdODw0<`a`ypU( z#@e0IVg!e8HXCAw*QzdJV0rDOI*JM<2e74M`Kj0m#<=1k0VM zFb#D!U~xLs09b1lvRH)dnoA+(zVU2o<{Q(Z@)uOb^|ZRSSA=hDBMdDl-X0+*&l&}m zQdc@)i5I!jB*&A8n-+6=vj%k7CL^)JS_-YUActm9Z zlu(!r`ZF%H_)jGc{B!rQs(WQtP?U`p>2xEdmgpNh?+gEG#xXQD@Vi7cfc~gn0+5t`ARl+ty)#z7S?&+3L_Y?uFoX6U96Go^B;DDa!Kg9^K!P34CD9$6BDb4lEX2s5VxBHY< zgiovh4yOkQ37*seJDxjSp6e|QH@OTu{)w%~dX-dMFu+Nct=4eRCi_r!Q^zPH%o-dC zxGLUk3~ma))*cO-ItpyU-b1KtjZ2Mq;BdO{n~*kg;Mzz7p8zK3nkMHgdBR>t@%%#0 zRPrpI*NQA@!>;2xU4vpPht<9u-w;97Gh1QClq{BzX{-_WpN+3^XzzEACBKya__a^F zoAQmCM{`=D(gDUr@JG;yW^qhq$M{0!sCWF`uH~|jGRm6O!Yy;+)%di-i|i+WTw!%- zc2_p`*iG*@hZ)hXs zoYeO)-Wm}-%@k=?tWEY8tneRf%jvH4tu4P3mU81;o4nVpo#|ksKp%<)&(St_&-$tV zlHatoZ_-wF*!zNWsL)y5$I8fw9r_vN*Kg)@-JfZyE~V$K(H)Yx(94>L!4;f-w%ami zv7IEn&t^27VZY^Fe;kT^^NHSgn036$*lZ8AEDlS=>&U7#glQe}i9N1mQxET~i3%^z z6k=RFzhz^UE!cbgoYvv%30xB%MLYqF9$BXw8XARFAsh$A8ItF#M{n41Xb;1m(~=~| z)>!y8aqu&E0z~1Rl#YI}J7uY=Rbps2#lO=z*Xm2eTh)=M@U2-_#5gHV53qd#m~C7% z6s@cyZt8-!ddZ=e8{JYGS3Sl!S{!hDai`6oQi|9&T%qt7XW&e)eAVv{Rf?aOD4v%u zUg^I7k4X~IevdW)6uGbwsMkI-W-v03;%1F1bYIIqCDi%b;OZ;o)GI3-toY55Y@c0nV3rM zD!i)PEU+8GIAJ=x*PVZ_h>=p~Xo3@{El&}%R*sJ)!I^kzSa+pQmA#MW)I09?K}I3H zX+5Lp@ZnAWgzF?fKd!`vgSuwCctC1T2&#Y0P=zVQ+N`+2La;T)NsUW09!c82iW_&l z0tSB}?v&^De>|=~s!pH_tO`btGkbA2&50{B(#%5oV2cPgqwbEte6%u;^B>zQhm!$} z(M={*5?w!b4d0;R9{#HPywH&!^@gz{z{n>@EV8=ps&KRL`*V*aC@ifsEZsf{%ANo`kJT0r*1lPgA?S`3{TJfE*;;Bm&%{~D)~=i2Kd(l zu=b0y)Vqc30Sna_*;Z$yj86rYZal)%?Q^9evJvPkW#CEBt+x}msmGEi9PLr zVe0T6pTJnn;Y$WurKfvDyv1wgCQ8M! z+&WB2&p$mZin4#wI7;D%+VX52A9~kwe5He$1tOXvet~b9h1se@$=Tr1`pS(pO$ViH zOJ}&wPFVI1*#?eO`VFVA(a^YDx*7_}iG|fp2b9L_E6eN%lSyVij&lVir&aq(q|c2# z0Z7&^Kh8&j`Ru(DM;eu!&>}0YEUYQj$adl|`^6%eY_V#|jFja`rzwk}(=D-P{skEM zvuWkbjl#4C!xW`vvm>>PR&QM+MctBC*65Q6*SL^uLbS*P3L{dg9d&a{ z#@qdG?W9!-*I3n!FjFPf>Wx$|`Bg1?5j|t#1 zZ%_n{kaT92KUwC>TQ{bhWa!}h!Fel)ssC%H<<#ZO&`6rKG@f*W^aWzR$-!aOut*RI zk-T@6WAHWrGGz~O&t*v1T6_1SSmqUs5sxM|EhSP%L9`>H2Xh#NCI{R9)o{=vXFw9D zN-4iNWsA-eY_P>(6y6TUuq((p1hd~_T`4#GqRE^#YQgzV3e-B9y?1y56oONOKr$zx zzId~ z2JlNf(`=!49^uRh{mftOVWZtxw$?V+EKWlW>PX6zxviBafM%A)l$EQ#_u#Zrfc|Ik z9tT;PF#IvvArky+sYRm?`0iihMnBYmS!=9qG=Yz=j!EbBbx=?hGb5?Yfv!+4W;(ps zp9;~re410mTXqjI$HY!nA8r)k-3H#7FHzJ;!z9EkQ}cd!8R3>&@C`_Gni4PLNuU%a zw(?baVigg+^R-w!As(LO7b~)7->oKR^xIyE#r$no@2fHVqa)PO4dsta@;pjN zV?}-}p?WD(+;K5qeSApxa5P;NFsB)0uG=8N?nKT10*qYUoG^0^fYE-7Q~0whzV^PaXDjW0)h3U2BqLC)7T%WRbwUbIO9$vl!3ZM zVsGFbCKe24yuP}FGiJtUjjbL*c1wox_K55KH)kf8C4{fsSYMe>;u5`^p@?GR@>);d z+uQJtx;krOaRZE~%Tzg|$Res3C1!pY zpC+Tv3heGx{_tcn*FR5(AsQ^Coi0-MX(#vz6vfOvMv2)$)2S}tco=nQvb53|DD~sQ zhGkSI9DlZ2E$;gdwuIu*8YSF59cx!&LVbR_kMgbQg^b00!`i=Sckq!mCsZV})Ep4o zHg`%Cic4<;i52z%OmVEWl0~D)z%VYV$Q>N_>iF9TLaNQA8K_c8%h!G4t9!Cjw-#txJqEL?#3^9k#-Gh^wJb-{LNW;h{W0jNVY!R z?Q z;_6b)-b5!)!H4*}i9_WrPtHQIUkHBuPCUp(p8d}7FEgP8P7)bebOD%AWk|6K0AVfSX7E4= zmojpFL>vcqSnM);YL+61G14t3PPSWTCnIWFM1&iyBxx~Ili1^^zQLzgkL#6wb>DjX@*5~69Y{g% zcwFyZ-6F!^3y?>$SNfZrnuKL|=@abY zb9r0ha-SRzi$OMaW3~O_38d}?T9S1f`vDU)_nJ2RW@{l;KGy|c$ zhm1V&2R2)~G}E1V>Dtgp>BElR^gSL=ujqv{F4n*Sr*bh0W27ETV1UDE7AM_Lg$u8- ziXoFdbu}#8#2|gcTF+W*0T-Fq1r1;aRYNFYy{TwtFc5ti)zs)9fNZmEO>ilq?2F$e z?zU<>xYR4`J-%cF;Q^XnjZRv!X|(L1!OuE^NcYb5^ zC5z-D<>;{6emN{wU+x4fSsh^RUKnE21_lV>$&fxTX%3GtvGCPSvFy%~^v(poQd8=2_tw)JmV{qpidVRfN8Ggg zwu;Ap%7NOKbc?qL8^}hB=$tK~;U{7Tpb&3rPw2Ly)nt5`>G=w8M%_3raIG1%XD{lGNB6j7A_Kw!C^D8*I2W~YUmCHM`shOqPpV6jT)*HFt{!HdGw6egCx za&fKes};@%&H2hMNFbTzAFTJ8zg3lXdlgt#^eA2%&4*|wvtb;Qx5(n#OC56_*v}^>Ezl-eVtVYU0SQ~xF9CJIkII+>C;tpq}J5}kbGdZpb~(b z?Z?iJaINrTI=Fr>vpBWRPNcH;=dA&!k4DyTuUv$>4*#TH6jFptkzIXsh@Rjvc<3Z7 z)~__&l#LS=l2GoAY?rXi&s$hB0>G)j)L0spHMHXZQUi*#-Lj2bZxvp<((BpIsblXC zaFnE0Dxkh*AB{(f(9V%(yCt*0L``3^uMf?wc>>HgCe#R|r*w$$ z!RxWmJDEqq?Q6&s-B;+p_@wjo5+_`%D&LGg=q$jkW|Esfi(F0*xs%=~_O4bp7teit zpzS!wRx&>55cjCxJ)j|84&c}`FZV7eWY$0wR1-PIR9n8oF5FS=Zvp4&z4zYb3Ap0D3GEy z37@2a4VLSO>N08h!}k55ZcSk%;{!u^W16o%Q}>N`G&hB53X%x8UpjeJ$8%mBficB$-Og|Drp{p{du{bthGTZDs(o;ZQw z4p^LGDFc@Y0|_waEBslY^G2>E#BZ)u%# zC)6{Y?5TV`CCFr4c&2k5zCJZ)$BrSz9nQo{R&)BgJA*z|TZ;zL@5PdiLU zzJ5p83Inx`CB#JOlPMd+&1*lIUU^2eMQJ@EN_jR9N@=FjIVIH)-y#y_cb*z8kjCp{ zFwXSgFijMoN^kKqcWGm*L%od0=l5t3Czh$6SII>6kcHKYKXs1OiWVqQds%6m_AV$xLS=@q37SpgI;Vr)?xo$Z8jkhSBYb5AUm#MZGG9k=H$)#LcW~6NFXGKAsxMq_#vHp z)<{@?);|;PT|KIkOLs~n4}Cj;u@s(%MNJQ_)6Z;Aw0<{jMVz|5MTvyDmE^6gf-!~f zg{#9S5q7R3UaduEnjR2C3N#f@HHPKw*+uk^ekVw$0WkBLfeM=9K^<>LqXgmj+AcJc zIVzs!apbD5GA{F@WWXIXWF?qDyKLlH1mOa`_oZyCjEqwlekQF9UWB-ovag$_8%8>6 zq*oq~{@2o+iil}0)RdZzy-zo;j82tyKv3EIQKVl&n6xR0nLIqtfI#JYhdWMuMA5{+ zOAZ>FrerpUKM8b77GDOjZ7@124?g*S$UgZ}v11Z%`ffIQL@)tHNxWS|bDL0FH0oe| zZCBbcTtfn&)mC{gE{qJU-$&eBCF--}>m-ywKK!jW9s*io7$Oeyl9m4h5hd1V?EC=7Xs+7;y7I1HnunO{_!_+PRtMiM|2DU zw|L{HdsrM2{Nfzki$q^Sfy)T7sz)<1%*}^I3mot5o06}X_TNatB)VJ^UoI}7{K>?4 zv)J$j*U86dkY)*UlYRxq|hvLv0Wl@0(>=NDxo-ViXvR9oZ*{-Pm6|OYsTV0||Qd3LWQ1JENINqK%=dajcB`+nVZkq)d!CWt@sT}+w=KOpVhHO! zw`HRqY1X^Z`kPj4Ht8JKT(Q&v3|DqDavNRjVIWpZ)q!y4SSa)rd?DdGu`c#&tQ05L zat7^pxw;>_mw5AgF2oAw_EGwTdqzuiB{B4JS@O7(&B8)>n*|abYjL59!#7o1qSE@M zQ1l`8twjvt>h2c3#jZ)kDi2l?XOV~m3dr9=5qbl{5u&^4Hwm0Bm>7|fEhJ9jkCqDh zu(7kO=JEBSSVQ4sfBskUC^`E|QWk9~kqKLyDc$2S;42 z$TU2m4i_wwa*r=8cvdGpFA;VGLgKZ|zy;-65R!1l?^rr+P&i%7>gA(5P-02UQ>rIy zHWIRUi2?QPl|ZbF*c70RkqBhoh;$rEW4o&m{>br$2r_({X4-WbbrNx=9!km?JJT_( z1^;3+hR6>D8T2FJ5v@pehsyeD8uUuxv*pNXug2U-_G#IOb{4SQTiW4PuYjW?rc7Wm zskP3#d>tb(T%xWHK)Ae`8-NHv@4~;e_E#;|eegpkYiT9}niiWUK*d8gH@}MI);f$g zrOFu+fC!-ESto4v1TfAC=ST4dj(5+vHLwt>SqOm%W@phNv~-ZOu;B!giKI5O@oeOy zrchR#wDi}6kpsn+btz~yg2CXMa|!jnG%O4QVuqFBjo(6LMg@i>yZ`u!x2}A8zkRI|TIBn<@zx zBe$!()pbW^BFjisR+xgFI;}#`+rMnHI)=RUR_V+8M&q43cnZcPg3bJ`Da?l3)Sjei z>`v-+c=S{6ijnHM)x?E#29YF`M26dPy!af}u3Hn)@{N8`1UA3T>RNcN_u~NI(ua>U zPsrqkg2vS(ObJVR*&Ldrwt}qi+km*feQ*@_USyzcgDjJ}|9y=uv4QZmA=Y`mALJJn z{l^bWXd-nVZ4x+m5?VVNOO%FyB|+fU2CWX6BV?fPG8_7+WlrfYnoy3eaRVkMn zSxBZ?C)=>eZ1nquA$H0qfL<$*9i@{?cx-V6zK}P;mAMx{u_7DQF5h32J!3aXSER0zzcIEGE{SUX?v+TA=gW?Xsy&)R z23B*AvJ8vgK}|7$gf>nUWf4KwQ%2;CRBx^nef4K&<0JEoMM~(mMT}>L7er$Xs)-QQ z+H`19;@0N*8iVd8AbMJu%w!@QN`B-Z{`zoZEOgsIR>|A}&rdV>}Pq=dE~1=;M1el}3nSb9_#uZ`$6d(X!-a+-{U0mj(MFNpec zc*2G{!tsEoF+U7Ghe}M{jA~(Nplu|{hCQ{0A76x8ajKkZ&}7EcFrr4+=NNrbWiL;~ zO3GG$u$n9tCiF^oC9*L#-ZC;N$)`w0gB1sS@uHnyf+_rvu(qMJvH&GRrYev;eo0|i zeY+ZDE#0RiLstg%DihL}Y9ZG-NuK#E0}*&F)Zk3C7DpcLlKC`miNj z%fFm|&peejs+Mi3=)yiT^vs~?{J3*(DyelbrZ&ti1tV$FZU*$Yhw74&STkU5Z~ z;dNyiB{YkS!1g4ot;uo(g-M&yj&12?plQ&jay4pxsD2U(^r+6x)C|dCH9e+ynSOcI znxoY>i&a*xau`zx-%!Ef`^85z21Z{vdTK&Me564!=EZ0l0m^#PnmC8{-~!xnN73Z_ z1grS-pP^2bp|yCR!XO&i2rqt()MM7n)m3&jY$o2`7Lo!}5@?}VSKshx=1*Kki>?}n zT$(l(Sx)!9=ATAP`JU6l6DNIHf5TZZM2x6iF!xfR!_0?*HRQ$}^?I?u(Sb0`Q9@F} z*Kp&2y41?`vGhho=V~~upOHzN?z^VWjIQeFBX=O2smb+jEUYSg+{xwsgdqa zP(xTp%jlFXvylU8IFyvM(|BfkrgI{CYm%|p_{ecYV6mbD_#w%hsjL^x{Prh)ouN=| zRsmBklPb}t%<`_IVO7yMv=SbzXr!FEu#&7>*OdS#t*sYaE3<&{hx}sX5ag7&+>;9%S{u4|7(lxkiJE_1Pk6WslIFQ{i zNaUfizt=FaN@`R&OYg@hmSu@Bb@$I^fmOB;oC&t_<9jv#GSaz-n=*%b!s7jN@WzCJ z9F7Ynw;Q3DU46w!mCZd2(zCl*R51ehEnOn>yoP2^ZKD2;!p7Hjj(t`gsd!bo%8r0m zte^?4DCnkE99thoo?wEC9HXz@Dp-P#oGQ<}JNK#=*9AyTEHF!)Wh6bIUn8u8=S9wp zo-#;GnywC1f@6X+Ds1+~QE%a4%N9ed0%nZieJ+KDlyh}6!b4&?-!dA|BMU$+>CFmT z_8!vU*av)ZeTELL!1a_t-byAC6l^RsW+ke|$icJH!LY%*kv^GzMOH6@zbT~Uq#=dy zZw%>v!$Qvq0%v*^zB^fWXf|7LvGW^)Tw_|bhfh^>@B7gkx0=Wo>9kPv!zA`f?w@|L zqKih1Slm>(P`1Rtdy)^K#yj-3#@Hi$MjQ1AvRH|$iAm+DM4r>e2!mcbC^>XpJP)w_M<-~(-DOF~v zgAD8+b2am%&aQt--V(0%AEdtomQv=bhPd7oqQ;vOQ&9kgAkRz6xZr|hJ>Dzh)?RCE zGBf2tEx%&cuPe{1#3-zlSE?p$5Yz8&E*-i{^04~&9r?s~KVd!kZ!jn9R)K>twi1R^ z&^q->MN`7>zzmh z3v)n>zb>NZiqFx!cWPK^8Aitoqr%A#xh)qU8c- zUND?Vd}JJgg|MD`DHhX&%0{DApmQE%hZX1(G?3Q5!*a1+|>kA#R_@%Rn?RfiDv%d=78fuZT;G%S3j3C$=Z6w_pj>{_0gOp&} zqYWrHMrjRNIuT^4Ee;vVjuiK$AzDatxwtA>EdbYYp624xWMFZ}DVeP{I7AH!Ra*DP z;k$)pj}I1?4nfv37I0LWQId`i!}l<@t0^W8!onLN7WZ~at2d;Umf+XgvKu5r!^Y~Xyb;p%-1AePr(upk~t(#3{3P>G_Evp3-K{Qg_E0vU)_MvgE6A%O* z%LS$bV=6EZY_g{9Nd77b2qcVQV~!~ZHl&OoM!I5S)r0D`EIWWB(QEpHebC@CDO=IV zT?jmsRdHg4sUxx4xlonl1f7k8zt&^ybR##Zg^+`=hn^-M78uZ*ngqZi|Ev{6ejv?J3$ z%!C`fD4nj`VG%NYOLo#!vW6Ta!sqUyj(B$s+@!WgIg-bZ;SF5jo=x zQtb@VHiJsa?nasMYO86b^CS3#%y&%Ebd5xEUeW>hEs*IBak42-_U1O9#{#ado+0YczTDQ~`jz;M+c(d_QkKm9Pj_-oIGxu(EBv16nS z7Uq?|Jw2C}q{-7Fap5`qm~;3wr)14M3Q}CxA4Rl1CLzG2nL0&fxmTY)Yc2l(3dI0h zcj$vrBP5K6425ml2^sjMA^JC;#Vb$j&0#!S&HPf*Yqp;zP6D@K#}`FW*617un%e|r zSkUSiVc74i6=vmH&cD-$KwIdVLLPBTPPsQY0P01$@P*CFhNXHj!$uU5xNv|v`9#+c zXrfheKwi^9GJ$QOb8M20gRw@egE;rU5Z3O&PL2LhA=wGB#Lzc5;a1c+l01M#1;^y_ zJ%8i>02`mGTD?~17O_&1(p8eJ$kw&jg<4k&!Cbd@uTMT16v#WQm3vI*24LDw&D5`H zq77xM#J0J=k<7+5?k0-4r*lKC!qsW}%uwWT)6q z)wXoC8-3a-Gd7MBW0*K)MA13Ja8p%zNNctgE|1JN*#Qm)h@y;GZT3|vE?V7oiyN3I zg-=hdex6S*KxuCt;P$FI%@Z{Oz@o3L^*$`eD3Dk-j;||A6z|Z=82j7mw(!i^nhQuM z>jpaoV@bYIJKi*b%`G@g%>_-RHQkZp_M~i6U6dOAYCt$aqRJ94glE-a0WJXGrJIkA z6WkEiD5hv+!PxGY32)et=zu$#iIWT3>a@kDZVts|LAHfBs|7$d60oyG(g_x^(F?)@ zjur_Cm2HKt!<5MbT-yp^*rwM4DNfoPtny18-AP)@7+%8UW))ZI z2{zOGRql9D32Id5)&Q6Ys!JxYrKKy<%~;aKb2R){q=l_)Lg{(kD~n_cWiGd5U|~h; zug7x-tMcR}pk-?HF;P~cf0+Hn2BpntvXr#^mpp4;D4R=l7Lb`KvTM3f3s~X5LkJCZ zDbBQAn%PLtD6aS8{LBYr1>S(jF_i+cXTJR(xbhl}HM9<9A+)DwMR?`&6dR#9GCN(Fr6bKNv!h!i*5Ih);#k zyjoZrN_BTaS!Em@KyC1#lS!>@EPHE&x?mXrYr3o(o45`Z$N~aRqCN4rS1L5u$9mWF z$-fYu7RAr(LeatwpHzO;x&%3a;V9qNpH*fBjs>akktB5PQv(C}2-yV?1)}r)XEpbV zQ^aop3mie(<+M6yzX7Sx^f$_F!j@X{Nb1jy=|&vjEAi;$lvr)r*=c$-aQ$VnV`GO6 zXe*n4$mipuAH?fZr|7SEx8Yukq(aZEEnG7ipcHR$T)O^7oq7H&fo_XW9}Ni1aSj>~ zZa0)|v)c+@X)%}`wy;OtGKdQ@j#Y94ZDBQ*S}okCjSEL*Q`w#Ms*H)R`pvXX(!$|e04l+R z#!wdkek#578Tw|m({z?d($E6aOX1mBsNWt1uTW*ALROKURnIkruj6M;6umXgfM8uM zd*O3##ZvY7eh5JBpA@4at#sr0H~n#We}x`2wC$=wMRs1cPKO(^-s@ee=r0897P|B2 zmBoLF#)ZK02MB(!$141tHh75}O?`^9;dFQ*U-m~n{{WAB#z$qlN~x>Xh+VK8rq1&m zX1Q8PgNw_Yu=062^uLSzX(MF*tmiXhN>SAiImudgNam7?Ih^NZbAH#WoJiR|V)8Pv zvACNyf)*Ejt6>x;ZbCU4cZI_S=K!u4T^mQD$lJBn2MXM85<6)NqsJa!^0hC;N)N)9 zY4(l7gvpuZ%kJ=`OW3Whxbr4Hsuz3IRugH@h&b3o<&F_4S!dv5d1HkYvW8|BfO%T- z(A<*lvof{_vI15_vX3$BTwR#6NvM>>Xs%jt-FtOnfYY-ZJQzu|>HdClpR)O)$)8{3& zSW4Py^^DXoQ5JhEj%`Jf-r-<%9I0R<2fCL>c?e2H6_!R&G*HSgEFnzES;EUJF3lh^ ztz`lh0I-6Yh3%w=#2=ROtzwePD<~CRk8>SP;Xi+Hmga~25IW+|AstG^+JjVvIU4Yo zdjyZR2FT$sOyS)Pk3=MPO2(Pl<>BJ1C<{*`C;_DCF1us#&;H-`1kF{cE znHUHeQOMfAr3BXjWt&JOj4TyOYZ*WR1ynD=?Y=v$7O|!+czByT3ZyY4&sAnUtxm8;6-J!$njQexhc+}l?-tvmp!hP_?{KWsI6>xJguIbZUJeG+RJqIw&(=m2t0(%KAeyWWVIVt zOwEpa#rH-8%am1`RR04>%ARgWl?xnpZbE2qU?PoCcW zL~^!iHl9&SrUM~WKI6AH>n6pjbk6eK%DqN1%2m09i&D}LDe`Gt!MtTx_^QljyJ8Q# zXrKpf6DqQW%(G!v8T81fEfjc7^H^K)RP48vmTtl(PI<53oL#vaO|=It$`@9EwY2?t&qMOW*EghEf7Pzyos6iYjm2%yi?6P>{ps-V}iB4M?zqEmM>n)G+jiv?2`n1Qq4z`$LiWSprOk` zhCDkV5jVotGMzxmlr>pWp9fO0M+ynbN|&kOK?r$A*&+dDl>`S0JSK94i$bJSmLnP5 z!kG-^;+!cogI0mjf<4|)b{Ks^f@vkP79FnPIBjsSK}`%fSS$sccUVQK<5{sa_kxX} zRuZTvGC8;f9AB`>qNefcpo_}R`A?~_Ep^~4F|%`Ul?Z8eRj}QFm7cWU7@Pxi#Jm*A zBy4UhJ%WCXrNonvkuIInJh*mKD)p69D{PW7RdOuYRws*Kzg1kFtt&|7)K-ze$eQEs zpRpl_bj4UaE2zs=5hGvv*44$`yRA#CYuom%vBO$^2`g9Nn$kH;9HDE87K@*PxW>_r z4JF7O*%FYNDIv}rrZn2Z1*xhCp(8ihVE`zj`jF%#VkJT4K=dkAMAllZ)El}Fu~|W> zZstx8;%13ahd9<OxamZwbu@ZCcJsrUHqA(1^)RWPYdZPD))9L8QjG`Y31=n%3S{oK2;^%9evj%WT?^ z0853?3jt>0h-FE_I6`Rdm=YRu!eA)jbkG3FKxxVgs9{F%6Ev0(FqT0j5q=WO={Q0F zrZ&P*H^O@uO4ls(V@X6^i$=)JaE3-w)qyT`+N&$*5SsEruP7%JvGPX#Ju5xdP_)rf z_K=niR~?^SeO-ybmAlX|8nvR+eQ|5iXn+j`bok?+$y)}~1{FBy-km^=(Q%tb+3m%TtCYH0I$Fj2(IcQXTQ0;tnX?8 zAmMfEKOSy3W#ZA+Q+7B6myqFMEOFq-fbKXMG*{ z8r{xSpSA8Kyez-hT^t)*ASG-7jqRXvwX7okD)bBvAcB^gc-OS-^4L#eiwt4-tOsmq za=aFDvvgTFqdQB^6)9WD;arwhw3^DJEAGnIE=|Qq*;177FDX%3dHrg@z~yUPy!zD4 z%GR{!m9uZv0Ihzl-d>@E`mKu_zBZBEDPKMbf{_80jBo8kt}xVMkg-LS(w~Zg??oB4 zAbOxDBq%17gOn~nK*>Z$y0NNV^>CSVtI8(}X+c><*(zDTS{d#hrVV1k-b72-)9Mu+ zR**VjC!)Efb+k9pL~m#OPZo?>${JEJg`{MW#UgJC8pk}-LfUCMM6(b`ptIMKM%eLP zHSCSgGzG1Om!-Yt-Y6}!JvJB{?=%t<9QLqNaJc-QXQko(ET5%E1W*hNgie@N5~P@A zRD8r7E)DrDdRXvCbLGF*koOXU%(a5awU=EIaADLGU6q}bvAN+mc9_PeHI^_CF@)D; zY9@d(6-HKsGJ+z^VKmSm%>l_xJqT~AF#V#Cmj;BRE%2Fi@atIY1g1aYnn5vZjR3hZiEqq zY}aI`&{3$o{A*@q1x%VNWhezxk)78ypIsTI>3!#g=(PUqVR@Y+>R`PbPy3KIxAC*` zRO)4of_uu@QvHM9^(s04DR(~&a<-GZr2B@|N6 z$+b#p!;N=B#AQl*>=YN0mn4)33S|}n=G=g|1SZ^hLi?opBm}H%%^mpk*=b4Re}#w6F;|UD?-T%ioWb1kv~+4rup>A!CM6 z)e5yX6udRhYlVZ9yT7$1?R#BdnoVS_a1A`Ehgj#4g@M>o_|2u3kghY;;y}*}2|OYW z^+N3BJ|b(Z0?nf*6iyPs+l3^9ke3@u$Fpui8zd>+Vv1X=I|m7KiN8g=;F^4pJ!qmIqnHf3tLFY|IpHTfAyCmY)+BsVzY!Q`TQ=)9( zy1n42#Wb2Zc0Ys`r(0u=6kQ8Y$s03msz^L0&zwsf4=AhAS8yygiIa9HYbmY@06o-) zVHne1*@>m^##Xl2yAY-+R8 zx?Gx_YvK2@6Fe( z;_yiWG`Zzbnp0&I&!=*$yzcI~t+msf$V9S|aD+IXQ=10kk^E789eC`VWTBcjBa*SpEyR&Q<>>nyDZ>fa42Efb3ux1^ zlnUNum7@_hG47x)a3wF#^*_O2lW|Oq-f1RBnoT?^@}G*Dab8!D5`)18(MWJ|t9g{} zR(3C^MgwIOw^7OWFfDoEEj$&ohI@-8US|=iwg_~VRmRK%R-uzqxlNN%O}koJnAks7 z2X$TgyWz`N{YBnN12HNp_}Iy-ZDlQocfHwCXnWm}g_^xc)5WwLqojVLR@Qs1;$RU>5$vr4Qu z>d5{t3WCH{%hUi$njCtm+d)CWQ`K{W zu-4KZR0Dh|Al4MvsX?A7Qj!Zs5q+tY#|pF`32Ut6sgp|%p|tE2FK)*Q;YfI_p;1i5 zCW=Abs@RpUXgA82pD~*pEGlk|+E;EI6w0m;9mrK@4~*M*M$KajefI=1>fms#!5tGQ zi&9)^_*Puq%)-=7sf(6!jy-&Vm9gKk1#-FO7@hAYGFG{~t#hZEV0O1ZVy{}F>fsEm zy=~qB?muc4xsmc!D74QcZ-4PmZjqqf{{R#YEFDZ~ZkRm%R6-^W5$pr9V}zX6)Yg(4 zk`mlMQeWKxK8RF}fOR4dORG~g1BE=O?4>vZ z-4V(YR-_CdbtKXY_pM8mZS+c%??wjkx4PJLZwfn6sv3z0yU*1_(zI^W>Y8@o74I~$ z4L+Hz0*DKndU{?QmfbvhNJAO`ASxP(EyuZ6H7_o|Rj!$&g5|~J)6?{Z-rjin5X*?jn=>S3`1oENLl%U4~f*a&Z^DOd;Mip4t zDa2F5txbwaX-p~bAcCW5X++*rHds%E(yKT{npw9h9O^0Xb)RYyFtY5dx)zHIp)BsA zsSD;UZwUfdIoL#p)#$FRB6&$WK{i$i3tLF^uo-0zBNnvGXR5i=;d5EJ$CYYFxC#Ye zA&jBfCP_A{>oUvTySi+0*L7L*1c&N_d(vGt(CR=Oq(4v!)yb^~xJWw8{TAy>sNUg~=e8ErEvALyo1ChM)%d;E2!5B9 z<|3bKi(FRssmb>sJ_|uCRqr)*Vdr^y)z-MUSWat8g{oR-8l%};Y=k!pXyy2q1hRVs zM3OBD7N6=ZW5(vwvJwb;AA#YeV`{slmKPfd^hh^5 zcN{5TzUvw&QCKWH3RU{TNhk|1365jE6h;U+Lbyw0aCX-Y5z5`{qxzOi!eqLoIRGWp z@uioi%IWaz8&=}!O~&T}wn3xytTDI^!zuNDja1z=kjNa-VOfMkS^;PUS;K(@DrOhQ zHsQh|HKmfqvBH~$1+8_}0;P@;3gVDB7H6`qGqqjJkV|)ft+0nTlp}0{k0?lU9gYwT z0Z%_-z(a-<6m6nx0AUd2KB7y|fkkB^Eg;&8&8O4Uu{?j$C+lV~RX&A|c#N)MJ(4lu;ooXo9Yrwbg05Cq;Q0zKD;g*)w9`d>lmM_C zqs1XIhX||H2_O|XI89UP2ncUBrNIkHCY1N-@`z<#&}h8pS>aV^pthYo`xvs`xasit z`&;gl7N1QeSG>0`c&slWrPc204Hw)mOXmB4E+hlB{{Xk;fq?l~82bb6vNT8rXUlKjM|W5UT|Qmc7RmoY8!h0nDYFsQeP(>4q^$ zMlcQ1KQq-qa6XA}x&XlBnY6;JCx?N;FhZ#jTHhtX*+sXqzWCZjq-_AIgQYeyjJA=7 z01!FWX;g^lTr4FP8%F1pJNv8Q^gC;^87h&JgqP(OCekGr{{ZbG_7bc-a;a)0+lxzz z>KQxa)om9`(`lLnV*nMV>3Ums@)qUJ5sWHGCwrU1qb8>`tjj2OC2ObK$SN%yW9P2w zmZMXHl&wUegym~zpPDM6Q=h9L+@Nl<&PvryIR})s>P7OO5o>gBClUrySTH`P;yHID|`U&uLy3tPHK(*=gk1d2}mF9X>Hs-ni00p+x z>stew#-HM2RSxs1~Rb`YG(pINb3rGcN<<~icf=XVatCPMmNG_z1 z$Zo^GiqUl~VIJ@~RkWJv41+05m9JMMV`YxXdU?Yn7HJ|~SC^yRc_C||ozmV`xyvnS zyw_@5D?kNMBoj#tkli_Bi6((vIdFM?(^_^r1zBVF45|$d2h}-}xq`@AM>U*W-Y%4d zt}{&M;aiRFBi>CG81v04(mFtGTB}*I@KUA5X{0Y^c?!%=yAM@>4H)H1Uwo|WZH;i( zumg??uK2|jY9by(T%vtHmT)kfR;OZ%T~DOPSy?1vta6me93QlWtDBZzqb(%TaA9k0 z2={`KEqkSBTjn5=N1}jhHnHy)eyc1R_@xhrwnmnwwl8;eVN(ZbS6L`A(kon(Ar5Xv zQW|XKG0gyf38v7LPQyN2sAd!?TByuxpX|zGZ zgOw4#b{wVg4nh_9i&JV@Kv6>BM+q9?VPVQ9GjR#pqS7n{nrNKzC>1?#G4xK-lEF&5 z<1IDCt>}OH^yuE{Hy+8lkVlf?@v95BLI&3j9Kz0?x5W}+t2 zgPhETPg$sY@!4x5CzR`Ll6_4!`CfydXiId%%obKb(_1uxM~f?a3mF7qtsE@%mf55~ zP8q;Sk)Dd|hQ<|X1ukt)H_nyIwhmSYmlQIZ%1@n3HoaYGv|y?~NF9+0&0#p3SHTvf zymnN?t|c{$d80)>Bd6}}qR0wBAY>y06l&Ap=r>m?3y)&4@teXOr5p+K2o>2ChXR&b zK~lq-RUuHgkV*pan@)RqRHSt|S`2fo3!Swo&+e4@3aAb;4^*eTi&e~prtg%QbrdX= z5whw78)T&EwC~l)yS1dUH|pe|aJRZGKhc2UBU;lz9Ga+lWa)%sJ zniZs*9Ibw*QBA>9s+>0Rw7RWrouJy!_L#LAl1Z|)TAgcIc2H`yjcZPFv^t$b<#{Wk zF`cc0^;^#<1gF(&StIyPW1Uv7jIC`Wmb;^Vrxx4Sn0$| zC?xQ-&kuJ5rs(H>o!4nxs#mrekpy70($;Evk$k|DYf&dwbD3HFsID};a!5_Jx9osJ5scwa@b7pUAMBUBxHSxvr**7 z{t*;`*hcCpP99sV3?ggqD!9Y!K?H>sLCP`+$}6#u5Y}Ib99YVhzdnH|9j1}Ejkrpj z(bG6Nmfp!CP1rn+=vwL@Y8eJBa)|}K5hN@P4$pM9JXr-9$JOC6#Ja%%B5TVite|qR z)lvsY)f@NO${2MbLABhcmWuUNq|@2m9uL@>EdKybuY1uQSfyS$Yu$bdcI*2RR|4kN zM+oD7>>|N!BobqCZ0%Sq?v`+ANPgkU!ZDN_T`eWyZZwQ*7%;U!pD}K@6i_e_n@4D7+EMcTM%gt-}F(O5hhO%1l?bL%a$rm|KX;a7Ae$A!3j0FDr|OgC_&)vvQ# z7KzV1E70`tGFHkhGkY~WX^JLA(_Y;yZS0S0S||l{>1iqYaus^B_EjOJM$y?D;vpr# zg!o-cn|MomN}c2)F5HeTQr-1(lpRUV1tsGs;<{~u)vVQ!#}#DCSXQ^F&M`>gQk|&0 zAbj99?NfB%TZmwGm^S!Nu&k@P$^5q3)l09o8!7!#X1W?2 zuPv$6Ii$1|8WCe#t-GQYlT)hL2`42+qf@jS$y(W@GL-X_>p+@kO!9@1T4Nb~lbId- z7fy2J*O60UYn6sF+MHnoaEcaFs+Ib%zEZ?EMj#;Yu82dA*o;If>0HA-d6Wzv1TWLr z;cR@R4uwDKXFrqmrJG7Yu$cjpD*~6K@wZ@pfTfS#JlMTS$fns@BP; zF+4fEEuM+{*a^v6$m6>0S`qGO+sb@G#UX{X?}BM*190I|dJxhbjuV?yg_PirVP<&= zqhC>UQ!EgGfZ&7@4pWRNqYE7#?5G43LikJ-WEw;baE!DO@k_0zIY$*f3^<&taJU9# zQ8#IU#59E$Ei<#cz+6XUxuh(|WT1UixNj9{ai!Nn;2zv0vD+P#bmkfsrR@$1;~Fjd zvJ{dCNDkO`(mm$8pi0S@PDf=CNwm>S1d;Zje0P;Ka+|kz9gv%%3U-~XaH4XGNWgcx zmsc{R_SvwFuEi8`GO`tSgLV}_vl*oiWs3ik7qDrzF<^EOUuA76(|a z5>G`WESn^PlAZg8pMXpjMvhYhZzUSZ#EB}A?cein{vqMYpVh>_&D+FP1Q zwGxscAq8r5!oeusHl^urY5P-N+vPZoCKqOo6&*oyb?tChK%xqxsB!vhR(f=JV`AXv zQF4Z`2Vq%~GqI$4Bn#bzSf}K`xG1PB`%zG=Ejs1a5p)!0sYi5YAvDBcKpCYKOD>8n zJWlCqH%Ak9IcxRZZ_Bo{`l|Z<*<9NQ2p-H zcj0-SodYylaU;Ifem=9k`#W&3lk2NjlvcU_7r$pz2XP034N#I^Qt-RffD!%GNhf+7#sWNw{*OQBt-IcHv6Y5spgN#Un>_ z8os0Q6pQh@>I5p0~QmV(Gl1eJOrv)Tx*&}Q@ zA@c^Dt7LoS1y+2NP(M%(0ZDwnwJsW7@KVIp-6~0dRiCH>x;S#SEuQ@l%OGeZm8<$* z)5^(pqiM3257>IWKx+%PFE_60LGo-4?whXa0rGh)IV!fRrFm^^uDs<`hQFhLJ*|=1 z3u_*>2tWm`!q78T=;7GFZB&{1*S~NWC{%jq^|DA|G*n_I771XvWQ>-w>bIPNRX>pHITXq35Abt7e}>IK#T!qZJSB$8E0 zb*_1!aD}aN&8yBF`|?Osq{#aV85N!tw!w9T)@l9=qdqq#{L0MJ{1clMc2=1Uc0^g| z7V|D$`93QMy+VgWwbrhPeScyz2If&!$2`@v5VSrpM=6Fn6stwLDT7JE*nwbtQ5Sy2 zQxO~mjgHL_trq&IELuhyijRvl5io7!D?2~^s4>~X&vUIjd$13Ivo`@#pG88)WwB+j z(4VK-xhtwxwg(_mOi#<{(nIw!U44g_Uch=Mkr>r1AU&`agKovUisSOgTfo3E*O z^4~*+VmAjik0}qbE|Dm1g(U~g|k;|44xy^d`nQ}qe3usEK|1C7yA;amjjY~WnyGc8Y#PzycoQiZr1 z`6a!i3krC-zAlg-9gM1{(G(9YZ4#RnoXYX&r@U19s$&FSbq%^#^%^4>+&C(!9_hwa zN_GK{%{dwM2@VG}g^ZB3)Kg_)#9E4=r*`3mfxzIP^J@$&mA>~Xv!Z^T-{I`70kMqPy#J0c>6Qk>$_l~yMydfgG#MoUIjew)zV zWw#2d?3*W1FbpfOb#}K%TrlDnHsjXt>L*>N9!|jFfVAxrvP=R%NYr(C4CdW}t>`uJ z%ZE`VbN>LYt!uHb>YL7qJuN{+S~ z7fapgyf?q=tFQ2ej#g8gKZ>1h4EU^a{W)+LD%>4PT|2OCd)9apObv$$KWcCNh0!-I zEw9$G{?0nH$5(awZ(cp2Cq{+0LKs1Rk%Dopwi$muCKLhBX*xmDY1b|9eJG@ z@s`NtWP6-wtp`$?Pc-}37PlJCq$q%JyzaA6<~WX4e*Isk63ux6tu*UB(r+q=g3Z#G zvuj8y^3Pgr`gd3*+e^Fpt)bV4{xb1VInm8`Sz88cTq?@btgF%;O$-#bYX_>)j-qOc z7whcMa=BAmX|-dqRb6kBXZNtsehbJ_gS@->EP;7hh1e>>OEqCqx@Ehx}w(Z zwMKnI^%77xT1+$Fkj-CEl6Mde5@Ln#$xn+)O&VJmbf7e0ZJ!&D$>!peI$`ORb}cv^ zm6kvoz_jMX1}SSW;@_Bo07N4JIL2%FDGWZVSsecW*=buq845Oo+x4pkEAFsDnl;@P z*4i)!WP26{B~_TKsD|6Pauc+U5=23*!8FDQH{DcdI(#q4`%Q(6lo^2vFi3X zBMV!rh0(jrj`7(wbFH8N{fa7$9u<8g{$nWe6{@k7Jw9=3#qlSoWKhdD{7ufv+~{sJ)D zOO3UZ(dvZJG+_&c8${EU4$3%l;juq*87}Q45)B04>`_x&O+t%k*+rMN4&YrxqGR^E z;G8egoNEIFEC$7Ln=1r{+rmw3C`YOdw345t)a`vG8g>eS5n zSk%hrw6wQAs)Gxf`aB|0tgJNptRPwvafJYoEn7bD=D!q13QW37JEJ%#_j{^Z@&_tf zOVgYX7TS_jwQW8*FV;4~BkHpOW(d1X`M?~h?S%!|q+ufAWv23vWw$C1JR;ppD67d$ z!^lO7DJ*d^6y4tLR4YkijjFv>jxg@*tg`GBuBGp3Nc162zZc@aR&>z2=uLH{oOa4) zmF{;>#&VKaTuLy%CE`IBgwRJSuyGHftT%H|_TLGC>It+Q zt$g}#GS;tBCV{1`CYPXa1y-5>?n0Z6tNX5F%19=8$X2>$pH%Hg+j(2Bp)i`-y$40< z@kHJ38Kuv>aohZVKe#kHA*OGgsL2`5Q`@aCmYOM{SL+YA5)A;f2dKIy&>>^($Dt(s zd${#Y^6|=;r!SBISDw;qT_i&*X``QJTynA1s+o^O*M>)A@U{BAds~ZC^*vZNGTT}m zS5dNk<&L`6ayGN|5DnU=H&Ne20jJCQhQpaO+xtBzGk-nqqlZFvnh551Xk-{54iCnnRL*YO>8Y z!8y8rg3@8nuRK#^B%I}3k5BH5e+1^hJEyUoma`0h`Wl^9n2T@n-Ki;qrSg#`3^RoU3Ye4RgKFo2QN1mSr^b(Ap#{<~rNS=9!8S2_!Ul zSD?}GXD*a}DOjIw zeibw=nipn@qo4#z=%h9;y4r2rX;aj|VSk#?xO|hBuVeVuZ$D3T1~YYpOdhIX)HhKE z-^K`9NF*!AlfMz=@LPIlbDHccEQPZRysE7ryDILR?tJ%J)>AxGmOYM3$>mzrI^AJb z+htyBA0|qBDA-3FZIo?r!s5ffmV`mJYQlw?&5H*F4CejmKS8Iz7(yTSrNfI~S-AZG2;hqqD zv3-y=@z^<4q#IqdB93C*4iLPOZ7^A|a8rXVfcHwhqS3=429I>CR0G&5_#rGx0ov>* z%0}!enwX262uJw1u*dCFw)4%QjiK=TLY7-bdnu>TI8rvaa8$ep3kd^oG?b)y0l6Hb zWbOiw1t!hW;Y+>%7GDHgHUw5$ss!c;;aHz7=W zz4{%}qr?8x%J;0PI$#+|){WX?oLz*yP5m6DrTG5IYr0W!2kl!I zjt{CZuO%dpx>|T{FjB|>(qm3JTDd>ekwFNuXrsF{pfqkksHK2fBHZOj1B4JN)H$It zsH~u<`6{slAF*B2^(Mm}(rH$G6|2_oX(6j0BlXeMI?pWqI=G812A|NYr&Lt){H~7yYrg z)%6ld_lnc%`h|B|`86|0Ny@Bu$m%Ko08w8o;>?=oySho#%HzpRk4EA+PiBaN!rMM4 zF29i?i84Q%QXnO58>96I#*hxnPwQc?ed?3~Ikn$;$0ehICB4$An1A~6N5%ww%pRzy7!h{vL_Rx&#m~}N(mFRS`2f_gC6i<$?B=(&XV`%IX8r#BM5b7z+_^o4?{bi6;lVM?R zyo+uHrk}kmniptV>ndh%x+yLr?;)+7Srv;$GlT;@(E~_FIFcz~q83_?GJyoMlsDNG zrd>mYm5OaeWgH&BmNvblcU3gg7^m3&72k@ftcvPJ4iwc(Y+rMl?6hG)7(vtIWXx+# z0-roaHvl+9*9%zcwT;|TlH^5$(P0B^QoHrbGjratjX zIe4%1H$d&}z)3PVNb6n9&M;puPmB?jt za92ieT>f4~$(wv-R@_D@G&s4!(N76r-Po+%AR|pQ?22u5*C*g4x5V0-fA)pF0MwJ?+Wvve(C>^HJvXb5Zi2$X81Z z2U@Arw64?1<}hR@==B14!+WQlFxQOMy2TVt9$T*PRzFQ6a$2#hE)-GJ$#eQQP#5Zw z>ITzWgSE#6E>?y*{{R(Sr>yiBXn_P}0m-m=QW`<^Ro6c`N8}k>gpHyTwaUvSb2}>9 zkYx9}q}KC{1YR?XUQ zY)y7mN{bBVvcW_l-BgRUmC@G|Fng%1DecOGXgJDH?lHo^O{sQ54Gpg9Nd!_36t@P_ zr$0(9B1F?kDV7?zZt~txruJ^MK;Y*cu=O zQ=~I=@(;CJk-J|=#)p?Y3x;VGSg?V>l8LqjEu4~lOKcT~2GoIBQc?*75T!$Jgh!Ml zB;`oTf))coG0;2!a;nb|en<-InQgL=@!FBadvU6lNXIvTsyjk|6AN}&2~p)Y{{XXS50xLbL`v*p=WmK%z{M+9)A$gG8QMmbRCU;hB5 zY^>~2^0J6TE8?E!(OE{3Q;GFXNcQD68ZEUli?D(_W)y#|{Z<$OQ?`njsoaF^pv9P`L@$Z&``cm_uS4hM&H_C{pe`<|{4igu zfE71MEc|6#)FmUM|0#LuL6dimbC-*p_8+4f4OxMT#6v$o~-D9;)8$iV&H) zyg8|{=Oa^3S7*tY2eW(FFN?6whHwwU9dC7}qWKT{%mOiS)n_A4E7UKIrz-HJWVao! z13g5Psv<$vPNFH*`?>6Dpx^+)vLA{b0Csb=nI0vk2$_>?_RhWp`yfvoo$<8pMm`a& z5U(_plPcL6?^@KLl`g1-o-}})enc~_cLt_^6jJ)d_PNRPU8OYm$|;gaqT;lwH*KvxrlZ> zGG&yx61Z+HbfDE}f4*=WIr4w8Hf2WFOW|$r264 zjH<{_pkvMedqR(8@(`i}tcNgSdfb!ABv-j`^4dufQ{vJgyzVWU__G zI#6(CLN$6lnUbI}$;y!~mIeLTkoN6LZ|pnza;%4mEHG=woUWMmvD+F zv%F@wej%@Yahj7@SP?42>nzRT;!<+>10;7;vmwI4RS(G?-?O5f;%60?wr&L0*h;G2 zie7vXk!{{*90#WqQfG!(*jT=80h0sotl9fCy3`!0cZy&6@j1=%K5(XnSXEhjX~w>A z8uqmQ;$@&bj`SJTiClk#xJ>sd%g{wf389M(+O9y%y5B zS=cY~iJju`?eW!E79tk(0qq?rbdS%u*7S>6{4W3hc06Fuvf=^bM0(L!WAkCVA z!U*yb6UQ@P6StEqiB`R$ZgvjSE8%$D)i$-ON$Vx@B_KI^`(sk{JxHhnBRkC1*e7qy!qZ%NhUMk?X~P zf($jfI#pYzrVbM}7BNXm-#n*V{#1Mu2oR6Em%QxvmO4jQKBBD~ zYm#7q#VuCswktFLI~8}mbf2*t&*mhFal7=tJksQ2@o%xRz0zAXD~!c>zRPvcTieaq zta(nUkZl5AE7|O%50!`Qvlm3nS$eB13{j4h5 z%QP42*n;)R#~SS+{B=M8W7)1sq~(RoNRf4P|U>_-FkdKN%e}p zCfU{NwxA2<4hYF#PLeEX6}#@Nxfdntz%9T@;}W{_y8>AsE+4d_1#B0i&^v>5z#Czo zAQsosgJEA%c~0>>(}B<|$y-SnICFS9lZN9X-*Jr&hH*KOEwAwaaYi5WUp^D(2&sq$ z`_J)XM&)EHsHy%+HFg7~&C?jf46gX^AsE=M+Qi@K(0)kmOm?XJdLwPvrW(x)9!44~ zy!_gyvv}4o_cnkXmF>tao^p6N;Kk!E!70nxLWkCc(ShdO+O=iDx3-S#uWc;$W+V#a z20d?Pr=%ec2+9jVGmUujvR$SJ>f(1Zr3;_g91K^ibh?tb=$rE+mZWTcxb?~6=6yT0 z@1;dEfeOyBL@vcN7CN;=G^}ma4*1VFV#Zd*_VQgCo9fN+6_>Ia&i-`{yh%B`FLZ5q z58xzh96toE_GwCsH!%r+m+IPJa-`*~+n(ZGpLJYvkF%v#o)E8l(0rw%%4Ogtv(M zo$2{De=m76>N>$Ut%&ur21js1Cuo5pSyRw>;2D6kcCszyF3LjS<05Lnfr(P%5bYB( z3A#RM4$G1eTpq$YY8pEXw4?O8sdPWInC1;zFPRS*ER-FcIvjVdRl>xzU>)}%j&k$^ zSO1CD7^ib(R-Q_;PtZfk#mmv5hPy}ayPMBp8&%?3pbV-$!7Q9-%-ii+n`8V32xAcTr|tW=F0^ zkt!rR(W>F15!%siy*IrBv^9rZXhlGMZ+3xSiXgQ)$of)~m68FEkCKE^;sqqj+MJTD zD`3!sXR3o}&zpx|QBpKA??BxGp3crHd*B>V6L9OqM{NJ#Ae}nV;XM zIxo3MW9jZDs9yix5cXc21=|IzWiu!9LL{3=|GJt>!ccD70YqGVB-l&Fuuc#>DM_kh zI_q6SUglbZsos88Sl=R=%%OH@*`!*Z{7~$)ycBIj+!ZGzeJLXT1CJ*6O!n1!F`Pr3 zjhUd|*M!I?du)%+slYucW;NZs2s}rdn!?NJfVRP^TIQOe0T?xvQXlLICF$o;q)S+V zEbYI8{#p6sjuq82Jf-aWYjGrGxSKw-EOnP4l+au$axDQLk*~r3s-pSb$6h&}EQYyt z!hHi)FOiD&X!#wtWh6|EYX+b~M`TX*M;C)cubec8ge57xL~UG=2%7R}e~}0;O@Se+ z35h1NDfQE+w3!iPV2#r9F8@sGcB?H|Q7B&(FE`fvu&{|rr+CL;=@Oxw$0#h+kJOwNS<`j{L{MoIH2OgZ zc(7%{x0w#CduoVAAiJl7eaV>CYdNAxDQq_Eqkx%(Aqiuj4jg74_ zYaa7janEM)gy_L(5a0H*s0Fre;j}Pgl*ng>>b5jU_(?2tcy9w-qjeRJ#?9j0DJw0V zrP}hg+N5N$=M?Y!ujMKwGBf0oA ziwRPExvd+L*DkGq#V0QCr_`w0O$}6YonJWB+O#B9p9X%afGveRfY;id%FW)Ub_^aD z-c>(2bqv5$+$K2eyg}pwu=2yu^Ah;@wZ`d33hNL3ti`FvB-bRq6*|Jh?EA^NT-wJZ z+?P8U3gxptjNUt9tJWUpJL&M%(ZHq4UDXHBK@$*J0hj!jTgEfsmNO50W9h*Bpz~V+ zS9~OJP23sr&lT#8I8CKrFMbr4M5wpz!{)qo1C!YI7h|3QQKv~NixoS9xGHBuS_iOl z-icvqZ>uqEV}&5rLz^9k9xu4w!@>iip2oq)Zy=Pby_(TS*_wk`EX}9qZYpFKR(vS> zYNnmz`+;4$R^Pp9@v74sGnGZMW`E%B@jaocc3XmIU>&LHoZazN#KXHOar7LMRrtp( ztoa)WX!H~HGk}fbi-To__99eqS7H7cphB)JF^VlW+m&{3H596vb^+(U6_;| zk&yPgHUrk2Vm?*GHY>g3!e{F&DlqBOP4Be<6+31F(hH-#}2~Z>y{u z&dh$jOHrNm(fGyO%l{0B2^oHBtpcW9H-t84-4asSin%@UB9^tXo!q_QPN7M^#DwjE z<!+;ss7ZD^9Z6=s6vQ$T<;0MzDna@B8>0-Pt9O@PEHuc^|p? zT|G}h0l_~o^e3~LvNYd^ws2Z(p#P#S*N)C&z108uPiQPXzY`Uo7Q=}J#cW%RI$Mso z0Htar%`lrop2+73Vrc7)s6H#3+56}g7I`k+`UmP9L{?kYJBEYs56eU@p||=;dm)N* z8=HPI`?*91l<(q(i<35bHi)Z1wiM1FgjYR&ZC1D-njox??B4FeA?A2}Y)vrJw{jF) z*~1ZBieN#5RZ;$G?|AiB6o4exRQteY(~xtU1H4pmQyO(r}`< z(13YN=MaDB?Ued0ZR>V(MYoGyw?`>d-fYupMb|3XFxUGWXxAAIiaC55elz_~^`+Zw7IRm)k&UdGw+jSU5G+#Hv z^cZp4DWzf~*qqfLnCjCI+VpS!<~Q#*18X(JtKELrj7S^JIK(Nuq#S8qQoWzDexy3+ zdMXQk+y*~odKZw%c#pOO;Yi~V-1w)deW;i!G4yYuW7bycYSubsmT}dH2mz)KO|>`)gEeaL*Lq#M4Xl*AT~1>^V7=9-mkRX0p==dD`$$hzM$lFN)`*~ zU4#EYILcoFkl~@kLA30?QB6ykuwaQ7T^#^yBpk5w57~jdA=K@&=`(wLet~PlpMrGIhe$uZSoUOiO^&pkTyc(9tTtCdfJ6Q;*`vkW*_UEDiHwnJvyJHzW7Xn@GO7~%U; zqD#C1aRc(=DquN?4qm|Q&#ipo#+@xCDmK1`e+XmTn^4G0p^tA8mK=gqvkP%) zuoq2SO45ydtdC4OnHXk0Ao%vx%U0NUk@TqC5E5X8QR8M<_()m|TV5qD8hxovdd>~K z82D8{Ci~<~?!WO206hbiyFX8bTFhL|Nj(aN?k3)6 zOtgm{k%rmp29~ZKE9|XY=t0V7#BbGM0>R(NhY+vJ21o7Dlci7YXiAWT#3^DR!PO2y zA^CO_pNQLY{TOq7)fo+T>$`=Q{?(x5oWDsHAo*ZyquFkn*sx7G7{6~rWp%Lk;*OJS z%bZlC|06$_aqijLVprzY>bylLlsY+Kjk^slhXm~um{B*_$2-Us+Xdq{4?-@?P-Yyg z&ib&k@v$MBI%QZP1J&Km@ zq*3IQhr8;eLp=kq6tb`A4dT%3uPurB>R*Hvp8@x;1n%=qrVg%PC2kQP6~&CC%-P4L za#sYh$%Fh%`xk;(3+x!F=-gG6Zw(9@H%5M@SC<`-n!f3hFA%J~SVgXQ2H<&T@IC`p zNrMyb!o9|wU#FCQgi5hLIb}Hr(j8~mPy^fFbUbvd4x6HtQYSpZEzH}HQvfL@G%1^f!mf-mYpQ4($u@89$QW%`)OA;&|KgbM8L(I6 z7=n3j(@h)JZGzY3tIu*HF0rs0g5+b8iXM#hH+L@3L^G-8RVhAMh zz@55s{DDYdy6GA4eeH*7*(My%s7-^1a@2HoJ7pT2~UOEsK&|}02k|&~*IIi0HpKkUJ;~nc;=|k$xi8n`J zc2+JPi(Hi1!XsWXdB;d}=;5Taxk5AR?`4eF=gmRYC*#L1f>`S`1#-SRCs+_Puo>o}jJ(YhJ$Oo#dA0{qb0rq{IgXZzq@JKF_`CL1-dGR$H0 z+Em&_Hrj*X^txjFP=XO|EV$6jh0M?76pf}t$mZE@xAkX5e_A@EbsJUq)3tJwMU2>t zs%(gHb-{#IH70VF?z%`+fLi5?eb|>QDyk+#N26d(_l8MGN+jPB&YCWVi>=I(9Tuso zrsB^bd?;3RIrP)5{6+OQxj;6=@i!Jikl^-c*ZBvm4L}w-og?oDLVCkcpWStDY5kOwS)EM5cEg;r7NoU-C?m0{0Pl_2E5oEjL)C{ z2md^`{ICg=rhV(D0;8w4+Zyw_>Tsi{7V}9@_!sR@-njjKPyWw{ebZF<}O(0uvmXR$URa4VW1=_pG~nur^5I@VI8Yj zd<7cx4n=6Y7_E8JcM1J#1vQjpr)%qId|+$U0ni`!*{ihWE4+zvy6l!;=3qQHEoo*n6mOco*$c|AjhB zbI(KJ6C}3>cjy^l*L0%_Ul=W;(jA4TJOg|^53pm+Ypo^X#LH0nb4($OmPL<5?7SPW z(uNCUWd))CRWU6Un6Ow#CUlo-=0Dfur;qSXsF?92~$7CGiquR4%k249bj@}J-kWIhe zS3ZnK#vKyh7no*lLn!dI41@=8m9dA2hQ!ZQU{7*LTfjubbNbDb=whb@QUCd|W_3{Z zlDx*_IQ9$(D#9-OcXf)w`a~@mP-`Usr>T7$d>lEB{rd3wVe5a_UtY4}2pOm!Yc3#d z)9;T#?{ZSWNW0?WM`l$l?eWxEnuJ~Rx_M00+-o}j2iEVO{sU{c_J3eabsOb2{-lUg z|37%D2XW`F0@-7QnBeq_-ER_(FnrBh9+il@rjExTY@n^* z7}9L;>muKj{xk07W3xD~LE@f#xoz(OB3z@701!^2A=jI9!gsN|h@1-G7b$O!4X-VJ*5fGas4==%)Iy zmJHP}0+Cxwo^yl=WV81)uDI`a)4xpm%eOt)qc!2Junh3BV5ox=)JT;rOQLx@h7JsE z%{>E3sT8?R5Y>0w;R{USL~cPBAyo1J2Zw_w???J3^4H+#F55bx_Q64oZMd*Qn;h7MP)?=yYW};K!gX<5zrDoetQ!ub4Sdrr zc~IDW+S=w%)d%F=sfi)Dy!71lc$X(U(iv;c2XV1hpzQ970OM!CIMpR*zV|a=o-g=% zLgLlh{nnH16$j{nxI}>^NBh*b@XGa`RZ90DG^i9#)-6Ztg=>fKOBQYX?b>c47Dgv=BcIfPM8w=S3=Jqff^;7^+QL#$pCo%++k+*S4* zUf`7_Nzf)(Ou*!h7d%KpGwtL@aYmUCC{&=HdYs9-xGk` z+m&96mqVLWK?>dPUgfl;&2`cqzh`&LC_eCcSqkX7r?)eG4aF^pKqB-kV1CPMc_fnQ zuO)YZFA#U$^2xhPUmh;&5zgafHDo*Fy(8}i9Y5(RjOb4J5k4tY8bN}7UaIP+elFYa z#8w1g)~UM9NSp<+pIBU*PA!B>Odxh<5{KPat10JBC@b57I3FNuPZD%S+=92%Yu zXfZ}YZcR%t5eb3&TCzNeqpAmzhMTtd|HdR!&wzHGBLpYI`jhEjV}~)CZ&tIWlsXcT zO$luK2KY*jfR~=YO7e>ZT~4jbKa=HXGY72hHB>;L9qNa&AB~rppfsip#e~0uO$H% zeZ)HY-na25r|qQ@^FarjL1R0L{%r9@f!)A#jlIGHtzPUvW-aY6lHYmkyXpdkmXZc$ zPD9oRWIbY?;^f}y>BNUe{|(y_ANJT(RoDT2H(hWKwM9*QJ$^&=-E`SKY+zT^<4%I} z8ITim75m43>nsG3ogzYr5qBd2M{Jy0nCg1g+EA_J+@Q^9ya_PbTB2G-5BbKCsFEFI z{m?rOMa9Uxp}8G03xy!u>N*a_BZbq1o=dVz)@Ev2Y$S|`DC_3;K;*(HpF&*=G8&}% zZ8hYm^GGJ7&}rKB?tq*6J300cbA+qx2l7V1aiS9nd?OVhvU^kv!yA`7_L81EmwQgr z#kr*aMsTtal{sHYtBqXoN3p3yIHLCRe;6D>iEu!aCAKAfIl^y2|D5XdpO3^XdBEMN zkoX=1Ejt*X$RYrYjlQntX$qOlb3#WQd_CypV_c)z^3;;OqK zB}&+PM`q;&L>89x0y!=wvX(Zrj8(TZ0u@31-@*>rtk&}Gk=d)mQ;{5$(tM1>9@tR zN<|z6&U1&R*&KmajUikWKz*$hh6qCwdvs88Yi9xRPwdy``Wjv|0cS&UA4rl}bX4o6)Tx;;>bR%Vd zQqcY|#8l~@-si4MqaSExTBLYgRJs5>s{z`EZjuD!k-nS6(>X$GZU&J*6fI_0)5zI8 z5@YxZstOUdi*`@ay)X#-SH^If;}Uu))514q@yx!y_w-_8$@&!)jdhV~qbv zr?^~zbMlUU6F>63LCr0}0w0 zuF9Nar!rqcr zQ|L3m6GreMWF6Y=vyZ#5Wk*(-(zDB1!0Xz@ira_kuPdCBkW=;#lM#;d_h&#y6VZ73 zX_8c1VweP`kKKBRrBCDtqV7>trLsBc( z$Zcnh{*faydXxh3x0U$a9@g?lWE5O9k8y`1l*&Gt9}<~IHhv~Si(yn#q;Akqu549* zvyP&L66f5eOVii|zf~l}`M}?{&@bzZ#q#I9RAWB4wrl=jk88=yL+uC%CM@z86|1pR zr`mZ2@NEf)@Wx4?9gEfA-uQ;NU7Fzwy$mGQXUI^9Iq!eWzNyXNYg-N*8Lb}X_BXG+ ze$#m@?wRUn67Cl&TZl=qDg zfrO+1Hp1(sK-N5ypx@8Ol0i8ow)oxlc#s7#{%W`p9ECnKcK5c|+K#}#ac_ONY$0q1d83}S z^)$0Qh1u19CsAO{46>xY^5wtm4CTcD-O%0+ncr96r!A2X-Y~G1I}P0#U$Qo>P?LyT zgwE?HovmrSJZN-KX#m`(9q!a3QnIc73cKccn@!qS`fv|~Vc(@cs;E{MLBrZ6khW`KOo_9epO?ZYj zI}uI~(=#B=kLTRPd=QcLQ+}=_! z_(qU(D~s*dXwDB(%d;08+^00 z+OEzWwA7CK@-e@E%da8+e88qagTe%#BZJiJ9=qh(BzS$rcp=eWeCr(P)w{#4{-ycD z5^;!=Iqg>m;@UgR@NU}WOX=}3yV}83iKZp+B87M^_cj-w!Pe(=M~iv%b4Y*1rwhoy zBJkFz_Vt1Fl=oM2ntYv{@&lL{-AAR5NQG%5k3L(iutJLb{ZkH?SlVgAzu@Ym_2QJL zr;j!7`T3m~uWR!kxu-1;DBdx@`zGg0rb zF0E>_PK`C>GbUDk-C78ZCnG1{6`=;kPTi#pz;M2n1O(>|8(_ye`M?tKt`!b;x{5L# z{;->Q&4#G{YCJmem-TkJ}5Y z%5v{DB|E1TG`LPy&1VPDduueAI`uDdbl@fGQ?9?dXnw6ueuNX>`)P&eSYtt^$lH`t z4l1Bt{Bbs1`A&Nkx}7OVYDyc11W-Tp0qrPi@t`M^a-hDe$g&uS@|S)Zu6Aq9(IH`;~@a8d%_DD$>kzd6I(~(Fg1baz1wOsk?|3%oKKP4Wn zQ5=Jiz15jd3{zb7wg_UD*8I#3X148azYG>^IHw)Zl0-q1@@%UzwL+;A&DdI8nVsC-5|O6v0uw;-%WdD zMA@L+#p|@;zH^!9e*#}NjS=<0bOf#;mrp#&%^|-6{t9Zr;)Wc0N|{4bJ->l@TMaQ$=15evx$GXPcr->DLdD?C&Dm@iFX+;e znpY2y0CHnaN(^Bmn`YDP^K(8V4m97y`J=Kq9ky^O?O`8tU_joi9e3I{CgDdJ#&ntG z7pB3LKWuyG(vE;JTL{}$hzM`MD>d9le$Hb0Wf|>JvvbF+bwC!XyKvRmvHfI3#CLdQ zf--QC!5J8h!r(a{6RV)PN@&|B_80-0HCP1SWP0Z^DQ4|oE^#O+ePK|60heH3ov(-MjV zZ|KA3zMuDi*?DgEguR!tcU?SOS+QsQS#2nAnCXDs@)7Dac*55#6n#a76`$faq-JDv zJE;YGnp{3M;FL|i%#vaPrW{^-T{jFDEhBxnK{TbAv1AiCr3+$3#_8UC3>rQsFgN`( zR<*%&u<=+|xUX^9N!P-2(>AquXEj6oFnJ>IKoUKvNR^MV-tO0g%Myz7d#4qDR^)G+ zo0=u;#UBTk#39BD#?|UOg0q>mACC|io;%5SH@dxDTx_3r-1Tq58I8*#7lp_=fm#hd z-9fV>yZ#BT*XHk=rJXFC(p&T?{1g0exGz}{^<@rD8T<$2QY`YwB>YkDzalzLkhQ@CE5|TK>eyYQ ztoT&2ze;NkV)?1hB09&l4jr2$=EM?soyMd@430xsPjDAe|M!?5<7%H_$VA{Pkmmg-$%O8T4%%I$xUhD zd-bdSdqc@H{>Ib07Vi&Bz0|EPbXPdlZ4|s3fzvXsVdytlKh#f=!>(JOhUNACyg=CV~}e;sz~kEa#h39IRfr-W0qZ#34)b8q~jkr2#!? znhMAqE=nGoijQoFB9@$(je~!UUU?%kGHaKhde8BLmW)=BRkvdQk8{>WfOQA)dUXqO za#7dt!hH70-(1jUNY|v1AwgXKqbt2LMiUi&o&MxIiK>b$nS0DhDS~)nqwQ7GCr-`} zPfeoor;kPYh{w`65NADvyL8Zbd?pV{pfOBQBCv6PRa7v@wNC%N>s~M)SKY%sL*rI) zYh`e@FHgQt=e1P?$xB!w|P( zH15%Jeto{{4NGV&gLTPhKUwMl4nOafIfnjb2^3xW$j%yQpM{cYM|f?w#+(Ex5Usp<9{EQA*ML%KL~i{4eUa)mHVje{yv*0_v># zYX&2l5}4TAKz)QN|ICchd$j#63%wczygDHLhT*d9mUy6?8lEN|1Nt>3OMhrWnP1)m zCE`vfftsy2Zw$2t7-Lx89(=dQlT;``Y_~OKwJ!)Bkb*9;oQ!9d`A2J~@dD7codqVr zH+Tv*@OPXbRc+obX+Uti$>iGK_T+5g?X7=vVh!rOqhVYpcxnYb#mlFCs!qwxFRz6X zPUX+2s^>b;xj)*Gu{!fC-7%vwin`Zc@J3cK z>{!pZ?nfLoPdg;4Pq(dFZwD1o#i?oWB_49McnL?9MBk09k%YpOa|-N7of?dW-qkVO^RwnX%c#?9+H95^@7iGt;5F4 zn3&qi{jrSIH-2S${Z-bV^-pRNd_@l{bC|B4YG(rpJMw;9fZgXhp^J%*A+&KWFgE#a zdR;FmH>}3Ok4KRLgZtw-;HRcw`Ou3byG>now>X4F*0;0)dEh)o5{aB#TH6(tcyFnIQUHBR55tiRNGHWGt{kPpoElVFe`8FL9MZ4^*|UQV!uhmDL{cf zq9!+>p_x9P06@d3#@h|#Q*+b)>^ya77hf1Q_YFU`9g%x-|4-o%Tz>^)dA{22hL`(A zL9i|3P$+43=m}A#Qs8X)K=py}w@^jBGcYlp0;`#yy-$2vk4_p1PIiC`^V#+Mc-J%^ z?K9wrl?+!WgnsUT9Jg}hDj}zRjMi+`L5tSa>GOw-lwoV5lcW<2SB}taN{#QZO5+B0 z-~4!TUeU1T`X2X@2o;-lNM-8XV>);#aN|G4W0Gl~3ixj}SCh`%`0~o7TrS4|dLG>2 z?S^1St5{vu{w8fDW~y`-epQbA=-*Qx?y`gGpsfbzulc6IR|(GdKC}6M4rSs0EtHPA zaZmR@?O!s$1^<3+|I`Xa1b3QM&Zf$ijha=Ndbvis$yf_ABx%Hx5Qa}Hc)Ua(&d2{= z9I6Jxm_J9U@+QQurB2D|`5*cv`hhVI+xz}5_#9E3LH>K8BKIi{8uUYdhiaIPuioT!|Q& z^&T^p`RvwV`~%x$+{gNOWXkZMD^a!-ypr(x<1fS3vvzvV1CMzA_g}V|I>B=`eVNQ( zg}+Jie$jQWA46-#@BhMFsEIFNxm+U91W9dmb74tbA*iCk2JR~oCNBs8w8QCaO|%s2 z!pAUDWubVUh01@%)@(6_G*)#)(wu{)q%?gC&gw2A8P}ugT-3MNm=ltiOLEG&!ELj; z8Q?(XMGPcW^Fvau3Z)7g91+NTeZCXlB)Esh&d7>IWrkAavkPymG=)fP(a_@8Z-eg` z9ev{fo6OAmWng!)Rbr6AJ?a{^F}rMnjh!oTiZ0LcQPr)n|0P*&=HbT>tg(pAa_`Gq zba>3f(7#%6z|g&-up@Ar@m4y;(#L<`zDN*Z$TLXBO(5)H5e^*D-~4xl9qN=jOh z5`9Od>uo!bmmL|#jE(5ZSAaZ`ej=!)htl_vYpYq zVEoroznQV}DW{={Z6H%oGqQeG*2?UL8@TBEl6~)`cDC3y?`w@Fl(ZHTunsbwm$Gsg zT8V~U8jnhrC7IVqf|RyOj}9f~Da4t+L0)S2Sd$BB@SDTE){?|0FZ&5Olwlvr*UCat zc*NPNrYoNrg)~aK20V!K$jj-BQ0-#X7^bt&8eiy{rY82h*(Bt}YN8m){WX1;ZGO*3 zp+>OdhD58*!8%b5(zL%&o`i(#wQ-JI*9WCPy_G4XEZ-6`8pHkyBnWGkLBV!jOE70s zBC;qJE0|KpRAWd_irysE#vo@oV)V3UNa1)JqiH>ftsD`dwq@V%#waP*)!n(;rs%vfP9;@>5#$3kQNFX43CJlkm``Eg2#* zZ?&fHst!gF9k9PJaE;g##2SmqTG)0H?Kb7^6bY5YT&)NOEK={k*Qf8SIW?>*Yt;j6 zy!B_hEB_97q`zg^_VSv2l%e-@}lZQdyHNe^8JDH;moYcjDG_G<*i?*dgf?ILbkXz(OFRSFCt3RY=y2kOm4Dh)Ff& dS2m9d^@7J(^kE2p9E}4+e8frAy!LtF{{hqDU#$QD literal 0 HcmV?d00001 diff --git a/images/training_pipelines.png b/images/training_pipelines.png new file mode 100644 index 0000000000000000000000000000000000000000..9bab0bcbd03f3ed6154dc2a0de035005d3b10bee GIT binary patch literal 103185 zcmeFZg;!Nk`!%`|MU+-TT2K@L>5xV`1WD--6zT3%kp__lN$Eyf8Z1gcxl2SR6P5fz1MnTK6B1xprV2l&P}qL2!i0qNQ)~Y$Tb%PxpE&91AdY* zEyjf)cM%!!r>ZWAYZIsx#B8H~y?EUmPFP zAz%KitiSwOeP&v9^pTkTjjFxR3m>uKs;^vm|2{RX^yU?G4=-<h+G zf5Zej7Cnr+c8)zJ7wv^wWH4z&#Q*1o+cGjqdYAUUue8#aZ({@g`)Ym4)+F~oZ=FlM zGfns3H`3ZEVsG6d`0tCC9>XQ@zn2A25Te)p&t>AcL1=paa~bXRUkJ;8uObNdG5r7E z*PYK3l90@v9q+dO`BQFIUcvp+tE(2`;?dexR?Odj{7}-+ zFxf0nm1VMS_up6Qcyfd0Lp{e!StTU`ud{6_3O)x+sgD#-$0{ue@bTaBOuC@2k5%#1 zOvXIc#;2eNd53;|M*s3e{o2B}9e%qx%p%>|2bYcZi(N|-wN+|Ww#&VAiBD`rdrHqw z5B(Do64a`!+qJFheLsD|&rxL*WX{;XcCCy}j?gdn>Moj($J?c0-#B{VC{ zh=PKG{^qN*I4t){Y`Z;r^k}%+{vl%UGl0OPKZgLB&Q9?XH8W$-^W0a9T+HFJ7%9D0 zTz~Lc?cfVd_VbkA=Ee1Xa7QysJ?Y#deq6fMcGup$duP&sl}TiT~zt%x>JB9GWcX{=(oRV^u)` z6O(=e1-$jwX5%F;9v)uZPCFweCg#_XP1Tqh{P-7&if@fOKhtn=5j8Y41dUxrH&0z0 z@KEqOVe>n!iA*(k5*4QFTC--_&i=k53VIywmmh@`1)A3UM2Cr%Jv#tx0q7a#omv^-C++%)Rl)QH02>!;##zEIjHQF~_i9NL&a@6|z`rFgZ=sv64 zF6?W6zlezzz94_x=i%uoVQ84vxf>oGO=r!)z|g+3JeF1vH8(%sW}2OmL9e&jlP+b~ zAar?VA6)q;THfb6UQBjoc6P^cKxk+My`Wf@e1dUhSGscU&)!Tq30qs+>Z&(#uP3Ai zfAf*6VK{G25>fIwu(~MZ5185h`+hBb_2lH_CSrJcu#Q@o^_gG64?240J7Tpx6J8f$ zg`4*D^zb#lR#5m@UzQPkk1|pFdF*!bDCVl(5pZGk@bD;zSV(>9Gd(+tUZE9A z%3-Vl+rsX;^9pGyF>0|n-gzytD&V$KiWfH$aUvF1|C(&*_nsCao~~)*6=5`##*bO;~uFknsLyH^1}7(x3bQs=6wxiQa0t8nXe01Pgybm$shu zsYZRF=UvHy`5e|4XUCF@H=1X2IYTAy=}*mXf~Id zUw&*i`y0n~0-lUs>&@`jml(zKl`GBs@I*elkr5GdMwoBg^t=yVjJWUNt3{QUa~|s) zovao+Ei62hXb{_*dRl=&RR62Zj7Ro2vy(|5yuNs7Oootb$Lwebo9wKo)s z-e;YLZ{P0dd!KV5(@D;gPa_}5TO_yi#<1uY?Oys~k=Pk{)1oV+2;m?$&dySY8xxKT zo;z(1HNG04G3(Z3@PzI8`ulh6M}(5|kWrwoQC}Wi{~bb#jtnkl22J0&kS3}@xxKTWM|5{@w~FDPF>%V2=?MUeQ*XJLPn3;c%R+ytjg~*@H1hcICadvY09jotkBQ@jAUNob6L8N;A^PWcoO@t#J^FBM8K#%^MKd?2*rL_Wux zB_$;;E4kT_&psOywX>PAdZGh)YTHwn7bHqbN?W_T94Dgr=0ipQ{={=&;^D>a&aSVA zlIVL72)gZwL^BXIjtv&;hjk|K%vyV&CQx9PXF>lf&?vbgiuo*#6|JqU4Z2`6Ucz#+SzefR@U_^zNJo*v?%2NF{hPR1zI}UkEvmLJwA#)nr|u zd;<4BC>zqUvZA$4>wnXw@SroJ>()A5LkhI3a7RW)P<^c}R>*`Leiu;v0WEfW)BAF4 zsq{J?mAHq8z*}K|=N#$7&8gV#<+pDU3IP}VD(fi-X?|9ZBb#5#D8L4f_Qd-_!0Ynd zdA!k^T1ZG}rzL{sS-Tt4!u&i9Jv}DuU7Ua`dty=&0SSq}O1459%c|*65v72j;Py@j z%a6XkzUc!%dx5YJ;t~>Q^sd_Zd*l@sBg(jbCUSCeKZ+YK81xU3RJmB@rQS@eD_5>y zV-&aBB}aJ5oU#GNAq$nZdAeXasnk9=rq?5D*yn)Yg^_0E#b!(U0zKJOFu3P)xQ~ ziyPCRh2LQ{c-!niCFu4c^5OG#mcjP+_V-m)JWzd5twdd2eO`FxQm4k@27(GjSSB*v zBJax!C^|k0iU@M+_U-MqhbgJ4sYr_Vg{yB8Biwq?Y@)C?6)7pH%f<2H{A#;>x8NHa zLBL@D{{6$m!V-J>^qs(CSVp)N2vs%0*hlc)k|rkfP>be_+T#Etqa)B40VCzndurza z^3rhv+}76jV9CM)8Ii$?R7*0cs1Lk_H{0npx4=*eDz|#vNDbb@}FPr%;y3q7X{#p3bI}~mbSi3 z&qo8kcJ}tR3x7xxg*?d~K2QE2nf@*|mW0oFBjl+M#usS1#*K2ChaH23+O`0xvG1@V zp=4a5($bhto0Fk@jw=Rpt>HWP_l$mruxWc;y>^X;g$3{YWM2%j%jRh7rB1b-1j``p zhbH&{0y457avrNx=tqD=|5nGu8W)~u8yeE$QapK!s&Qe$7a|=okN>T%VnK~_adUgo z+F?USq%;{t=ao^eKnybqB%#65-!BimF98P8a&i(vb^qNOes6wx{=BF%@489Zi_>)F zuWt@t+QULTZ2TnJh9w7RwXL~`okDj^SzO6qdLOfk;!YI~`NcmUsstvkl? zapL3Scg_G?tzu^tD(8yZ*_GB=0IExLHcEj{3qrNMOYgm<@H#IWo3F@SNLPoQyErN~ zEB%j3%!ioMP2JtygU82p>NlrXRv3@~jyGMceHkr|r|_-p92^obpM+$YZ@jkG_t^Gr z+Z3=EW(|*tdC*us4YOg;VHmzQ6<$$Mk!X6)9bE;tM+sqJVa>;tRB2C7dx{`b0|C}_ zEXMUpRYyQ`-+Z}*La@+rI&pgxCym&+xYUx1dQAQ_B-%iS#0V2yCb%b-(%VvvPbc;~ z&aK*4tB@FCr&qf>VmFGxA9Zu5^sKCkJ$x}#&mSrzek%6<-QE}#MaTwmh$iwUS#WIH z>GrK#bE{uVcG)>ORSWbR^$>h?3~bgo1q+J;=+`{1%+N37kHe0H0Xp6SSb;y~u^ZYp z&uukMj|AXS@-00};d9KZI}g7nrB!#{8-?9tv%4VwZ+6ybr|Rh9bQ7n@Jh-%1hc`C9%sV*}=zjlLy=fTJuw_QfJ0F1yW=k=FAjk^*bApwE69|i%ET{&aU?4IZy zn(C#7;sir&CcEc#y3+th?$e=>oEekG@t&3zDMCWRINe3)WDgF#H8kS7cIE}2->>gbR|@XZGcdPo^C z=}FT>Y926ami~N21+9#s*m-sEG0lfh9RGTw$coX?|H$pybWS%uGzq zZNVDR`8W$TpPp_uxUq0S+u22&cF)Tcg zdaaa;UI!S5kHpqKdSGV8kl@xW+V?&_^WQtWyPsdZem!`4+Q4>wlq+doD&)?W$P2)w z;t<8drLTiAv9Y7DcA*O%H*phA58!U(sWPw`{=_V?n&2xj?R&uX^83}}j2F48J5!nH zb*SY#+WKe5cX4rH2uNrK2DAVC4(P7u5d!0T$N4D(&X!r2;I-6e*4L}TF9`u%P=v zz>V7vT3dd?l(P3eFpGH-Mnt=pE3>orDy^rOVfMcl!>r3NC4k&z)~hqwI_$$WtEsLI zLvZ;5dy8JokBBCr1FwaUy;-<%x!vTK0{UT4RPDLBkV9f-BDan_kGBn86<{F%A7 zNV{YxVg}ry!6YmzoRa6~=fBbeQZl z;dN_!%YBYs2~b*A&M#!URPuB%a4G0Fi{m-oe75|gX(v#qSFg3`LdIVrzdysiZE%Gid#S?{z>Fj*`#27(adKz7tqBf0GYhK z+btYi?|vA>Yrn9R!2v%}QdL!LpdLbUnNqkdKX>iHKeUb|Ke08et}6{p*xMH2^|Adw-yW&ozL+-}wKmZc$rgz9W6iy#B zd?8mnz=Yi&$rbdfBb<&Ff-N(IX%F>jd_!l}4abt%LZ}%m04i-T%q{_sRI--F>o`b0 zaEkw@J74^#J3D4WMAkPMndau^+EqWI%-^g(Ocs!L_cTv@C>m8xSF!pP^26|8ZNw19 zI71ktzfVo+ze(P6Hbu=>j~H7*Nbf-g{A%{O{%c6zyRa-2W;f0G2?l^BfW`u0V_Hq| z$hn_c^RBPHS-u;=8#CxzKsrq+;PTbO@?dKVAG#SK@F@7ue~HEXuzMcm^H`3)vLTQB7pf|mgZb`=%EzqivMd*?Qc6N3`ywr1zLJvCC#Lyg;of)upTNU3$?3)G|PmFI3Nt-79(L$Rn^-|X! zK>zSn?=+n~w;xT9=EHEEYgyMX6X?uAgB*tuT+1TiPA)HBewkzkAn#UxCVSnb!lxF1 zDlG?x3QN-a_g7HipX2~PCt=p1@5aK#<#s`mg}wEJkco+j4$qoK5By{G-u{T;_v!bQab|%G{h^LS7f@iT&*t`6BV6%4W&JUqr3l?_Gy}m z+Y~V&VJnG(SiNBd`AvYqTrL*(A3j{2*k+P5ddfF>es+f4(h4Ia{?F28kurMBqjU_ba)b1z<4gFm+n``xJ1yzI8mb zWM0JrUA7@7=sm|Lxi9iVC!6|(u4Aix?d|KE-yIxwW}>IZ2QnhI?<;4Hi2~3|i3fnI z9TqY(GXr5J3xsW2u1_pwdUnl)TF~wOWfssRGXNjlRu|{ymc}(k0SI?_Xh%njYApIv zjn=O;9L-?WA4^C`*j1Yp#1@dUtgI)H2zzp|Y8`0449# zjH2lQV)DURLyZz0&*IsC8H7clFzVh%wsP+%)`~EVH~AEh9ZvuS8_xU;<{L`L3= z%@w{lN^jmn6@lz8=cnX1x4iZT)UVK2ztiEvd6dIc^5R_nluQV)jgRF3it^aEV=;ew z2;G68FH4@|NVCQvx9($MJGuF=a{uzdz76mb(PPOW@O@DT!9*2t(89(YDd?#^G>@k>i5Q zHQD8zA8Ak#CE9xWABavz3U5sIA4yAD8mk zWYSQ>#_`EXdrJ}uiDf4yQn&`(wa2*{SdA1r2vpbB2G=@oetsrbtFEw&6mvcY+LO{u zBuy|xX8Gn|fo9p9!=L@4h<0Hf9-gxO8<{XV%=y&;mEYRdRw7&TJMQL|*>x-+VxDE) z+us6ekQN0Et@2BiF<0wLG|g1WS8OvplbFx5oB|;(WrYCP4rr>or;swFlzaD~p<~3Y zc>quPOoyd;QJF{WJE4tKEu8WjDVv13*bdEIU*Ii&vv0x8{;jLq+ zLdf%E=}*~juO8ejvS3YVK?|r(I5BL>cdPdHhZWzd^Qi-Zuw-u()yTeZy1ek@ZQyX( zGKdt*om*U7bvSxt{&uKes&?HZc<1-$?erihr1Q0-c~k}{WHgHprBk-%xw&)r!Qmm# zu=Wc>hNfD}F>W8qhWz@{PZ92X@!KKr|>t`+Re+NMR68AyQX`;SE66X zTEduLe-AU)vlqub2ZK;_K_wtqa}%6D8eYibqE)*qi?% zpD2_i-*d5&DbyDgug>(Ii|RGEv6`m8$=uSu<}y^M{btXDk%1us%S`3;xt1|(ZBS*d zU_&vtIb|>eIj~hPl(c#sQE=XE_t~|%a&nB{?XaaU*3k)8jQG0sP43qE?AtBhjNzk1cPH=lGDjlxL$?CUNJD^y@+DvrI*Gf7% z1$|g@rrkw5tVGk%R(|m{qz`*0d1z>8b{zgw%vhT8b&DA$5~SR82KpUqXwmB!qGe&G z3p(CiBtPp?eyg6TOv*vC%Wqh#py`<> zJ?xFN$aI$@OV`yijpVzB0|d7P8MI^}yv5E0n0fu6bp%uLI|b0m#ViSUPH$`!?UpBV zJ<7$>ocEZ`4}B=^L_Ezx|@s>w|f=e@^Ev%(DMT5 z@?h5@^d#Zn=#Q))4zKe@Vx5O3dfh##kAy%&vLE$aJho7p)Zp+^4To;)p|0JJUNCZz zV}af1R(8v5q?GRa_wQTq0QX=*1<9=$FnI~^FX`6!TnjMeBFNF{DWjYA>+WR1FUq-} zl9CEJ&Q}KV5u`$^aLIBQ_?h?l`HWUm4TT)KZYr)$PL(}8!REoCq1OIPb`P~oW9k~Z%9xo8;Y`pqzr>_B>&Ty(RM33+o0IH=C5~6}~x5K%tu|8Cs z(xv@6RLSlvZUzWfAnsm*Y@<={#xMVx6562$Q)5KVEq4ZuyHG%(;+aybNj=guM*LA?IFt zCMP2k-mwIcHtSFAr5Hu5sGP+xQOSl9B!Lz{Ltw^wS5@U)-OCtudbrtPO~tdHWhUc; zfx~L@_bvde+4g8g7f?rtEL(_QXJ%?unE3)2MC~RJZognR(V}!V0r4Y99Q*5_&PwY) z0R;E9QZCP~gMygzqL3`_MPZQHK;9#*Ly#9)H6}rG(7iYX-kd9;D_tsVt`WxLzYU)5 zKY#s#-I<);0nAi>#D{$CGVbOhAMtN&6lT`>_7*nw-h&4iFjYv_yX~`@4>AS!mHu|6 z{U8D)5wg4JGY7nGL{e}#z{ z0FyCXZh9NaVki{Zv(u@|-@s?LCswiz=3zY=DApGSRimt9e`P@K;uM0?&a0qW|Jg4l zG9IfCJfS0G%IowoF^gUZic{EqQKWMN+1lD_Vl1vRrwmxaEe0;jIi6TRTkvnME41sbL6`1t_=@@k7_!^(LPfp5+uw4J4LAG7Vhh zgit(a!a^KZ`kN2d$F@<)bpQU<>$nsmii-FTpaUrJDH|L>n7t^ki;TRD8lS$a&|Wh` zzjB38Z4OBVF%Agop7;mXl8|Yje1m$?WEGY`80|0AR^R6PI1SwSV8VkBpBSIYV(B|O zAF^9c4Gav-dE|y0xA|rRIcFT3tCWs`+BB<2|4R zL2gn`^iwU_+uap|Ihv@jFIVMZJz^~@d-H7sIj5{OnTxyos@hkp@ozMjNuNJ6=od$^ z5^pJ+&9_I#@mNzKcUf8SVN7OGcbm_gNL3l#Lh-CFp&7w}+w z3R{JRK%?g^in>}|SO|0=N>gYr-ze)AGPiMi_AKPAiq262`R2O&V6@6Q4A{7K9yvLc z{DE)Io5>Q(R|?8TYP%#4dTJcaQ4Ax@<>JqtVGCUx8C|`4wW+U9#g=T4?>P!fH+p(_ z5M~-JNj6HRKiYMayh)|`W6Zh{CsfkMdCK!YR4ojxGkm3Ve+GQZ#|{94vf3N%7qs1zdCbhW7NGk6!VP-Bb`00M^PrebBE}a><#dH=-jM)v__qVu%@89{WrT%I5PTU#r5f>WV|gM71J{5a#7 zA)COROOo!ya>E`N#T=eOpso;EGQSj$oxqv$#Kc6!Yz4|+9sEI%e5>;Ic6Jf&!i<)0 zqcXo-X|nkGy_aehA)Kg)W;PMUbp6~f@~5*?tvm@DjvWsnWEaa+y>5dCuh7|+7zkU0 zA~LUYfXAMfol+LX{F5R~Eh1yo+Ml!vIQ@gK?@Jg6k3bD}U_KfRUGMK#Ew>oyU8}^= zy=URbl$DwJ1?2yrKDUMVH-whjv=g1&4v#q=Km;P#2gBSB$q}<#J3HS|Y8ZfasO5&h z-rnJNDwx^NG~!}oQ`+W3O_zNcvpYaG;)~Loab{#@F3ZDIoP6eNW|rq6sFo(svIFZs zM{WpH4m+7Z*|qgj(FCVV5dU?HjN)h~zFG}`{W7x-Iz1DXu&}Vh(qFXj@NkV1!)u@f zeuRBO5TTQ$^qU?274X%Wt6xEU@H*dD5AKb+LJURodw;G94s&`&OC${hV-*BrUyjnf zGf*?_JncqGUt{9pzNl-c=(@c=tF~N{O-UG|nyW0)@DHVz1Sea7#;;UYa0}%GNKpmO zuUzIW39Bwiua(i_;$lMGjpUS+;;N68r9UY$Rs$;zgR4)%6B5WlTYd*R9BKoAV)y|q zKVWUm0>!@67e-rog9LRnHg(|#a_0<;$*4+_mYzQUwxnXky0=sEkxljoI!jvZN9L|7 z;oJ9et>R*1-`Me*Hk3@;YBuKczRta_k$DAhU8dJ1_ae*8cek=ewKmd!lcy^rCHHXCZm z!oq_5b?2OlgDFze@pN8kkTa8iU)J4Rwtl63wf=AtU9orR@@O{9_FzpDIdJXYZ@fG& z_E0}K-E3@uC#u72iPLgaHQ9M>m`!3#$m{G|^K#`AVmVC<#xQs~O#VG2Wiz~r7yv`7(dbn#!IRJ-XQp_i28Vk+!A7U z%(Zj6zv`bU7fW^qQfc+&kyp-!q3Z1daj@pm=mVk(rz|JHwC6NH34x9OS%0sur z(3e6RH9+Sz0>o8UO5PZ+PVeIw)q`@S`sz%&rrBac;FDV`*{X&?zMGL|Zw2Ti?0MyA zlK_~f_lNX(+?1^QvO=Mp>wbGDj9v%d$F1Fjus*?fh#ZH{Luoj+6;Ww8-Qf3eUdar0 zGBL?XL$?6y5%{tE)9UzXjp`p~v;;zjc)itSZwg~5bKikwj|j_Z(mF)RCh``@(K3RHM&pdex>DkF_2EBP;Erfy*oGa-t~H@JBgRFe1{@$cNhM2=SSi}ya> zHnqryY$7Bgihst)!$Sg)jps=>aZ6gn>y4>K7ImGdkmgMJ(`{Bdo~rWQzAs|e>yJMU zGHw(vfnT;Pu)M^{KNJw4 zfRGwW&`X!ZIqn&T3Fx!@;zDPF$zTELLRXS`rPO;G5SVv{Jpq&sL)VBW6P*7d7!&7o zuLxkG-Cu*2kCeuJ)pqkiHI+~77iqd~N(5g7+Ta9>>KoabqTKN|buDPhFsGvG*+43FpS+(-g7#<>0;fSuM z_DSIxcE!MfVJvFbW;bRq?IY$D!VKx{D*QvYK>7x&-VC@Un!nl3VmtLrR#~TYhEgYX+3^0^(lI_|6uKAgZWpU%vYeL|Hr>H z-NDjWZesBwPXPRPzA&v!gRzf$IR7EI<~B@?+lK#$w9P7lXNu@;L{JaNrlHL4=O>Pl zEFz!CJ_9P#_-6B^hn*FM9;=^Zv=i9C;d}>0Ig~jT!Ua_9X+X;mwy`rMbN68?h*-r> zPEH=KbuONy5iL%?~gc5~7{TUxZo`KPpiTpWg}z7j@kX0}?|X1nku5M{kc+f^4a(%_BqSy#quBWyx6|*N5gT&8 zs(;CTvUCa8&n0AJ{hmTW>-jI;*dxHM>QHM{+hulja%gXoh-Kh%C@*q8!jlDgeAs}O z&XZ+*GIcPg(Q$K<68$>vRmIXN?0x0fGjL%9>Ald3SA*iYVz7UHI7RL`It_v#=FOYE znhm8Wow=7Fx*uF-R}nN3lp&55b`X=I~uxTIl_!wUl z{FOnu-gX3l?Ij>5&*uMnB~w!02JQ{ zl{DI}1X#+jh{9FYUce6sTy`b#OZJVFy$P#J1icv0KEAK7@6Orqum+oH@BR9zQwXaa zb}$=Ud5hBPWyU%??)yx5O*VILn_>nG&r=OcGfV%qbu8b|meWRDTMJk5tKkm!(~ zklYSg!ft!gz;<*@y#_G|nCIK5WkRVaVE533Ii;hew0s=8>mRGWC=}Vnqm--4*Z3u3 zlg<0W1D?@*>t0g#!rtbT{(o$qX*oI4wv)g&KU&m2ts@p)N|It@W8<4ThKe+FT#p7; zCcJt}jaiLn^6~=-!e9Ud2#e!N3YIu<$uWX01tbV+A&V_{YH*HT)MLAz9hHC}9ovTr zm9FEMhzQz-%M16UOKt7~e4fc@xbF@dqi!wVfB)PJ|t1pZ{O z7j)THgC`GMgACo>ZEXZFrT(pEM#BdS$Do|tsRsx=w@`8r6j{{MCknVKd!Te;k7!vs zBmk0jDU$=J4%cchJFV*SC?r0ifqC|sm$izak&$CVH{iq&y+zcY+dELVu)028ts^|36wZBmcVA@*X7*V-hK;WZRnVJdwZL2bu6!dP&A;f zPQbNRwgcV%76|m5a*HEzdwUK^CrKseCOMOa0jkcAju)0&1PD}brO8C*jeD?U_=7I zVosiD9l9&7rEq$(1$h*_yVxC!G9{oe9GI>Yl38@`A?J!sK}3vmt}x3Hft!MFcXoF6 z7f=cFueq%!EqbgODjor!1rt3141POxkQ7Vh2|hkPj)S037I|!R#j%O?9JH&D1cih| zc{73m*mmaUjcPp@r7A-vw+baS#k$m=$Z!zvCWv}*bPqhSD>b^r1J zguS)V3OS(*uxb4SF=F04*x%nDJgkG7jdfE4G5boWo*5}iKHb^Y)`L|l&98Gx8y_D2 z(#F#s#ZWU7_TG31+~0C~RTAK&YIbA#CdLo=_6vV{YBfx}`^_pKD@ee8fsg!zQ6wk! zHwe=XH^)JW3jiFttjd((qiy=NlQ0u)5-6dh^%Oz(!5K?UM%1S~efpH|7;+NjeBe9n zh1CN_?z;h9YY!QgXjCE|w2KE}Y^e1Hf6(zI%uB{qrXN+~<9O|}W*lWxg#CoPyu1Jr zIV?x?+)U;LO*%SD}G+5_DFSp3{psA*YP66OQppKEy)YNs6&*3=jZFY%z1!4 ze}mps=Z+Ec-G|~b=&5^LnJ|g-z0+SlQjnJZs~L-?*HY$wXjNaDv)7;pg}&J_n&Nq2 zV4&TsT5Z9>?rvW^x8*1&N@7QT{6`9H^llt6ZvkW9yx(8`&ARkhNwCSH82K(2A2X0}`h2hN`TqMouxYrpD50!JF@d1v#bCJj z;@k0?oE>dHh-IXl_S~9&r;t28G2wVnU0ht;C67&0pze)@h1JJ@9s7I`RS{P6K9Wrn zedpW0N9V{c=4i0~w$>q3RzYDLYIy-4DLv+o)dcurA>L;R7a;2Hr(JQ`u|b2fv2!mubO9na`nH zvy{Z|tdLv@w5JcreX$IzZE!7NcMX|v8UlFbyOsp$41Uj2lxK#Xp8oG~m#mkRJv%%5 z2iZ(zowugmcq-I%BuD>otki#zjpXgE_@QC50{Xuj5p8OQ5(NY=c-4M^B6v+vIh-d; zpe4=et$KP-z}(y`HsD_BQ;Qt1}s!1D21@B{$iyY-qXL1M2#u0?DY>NX`kD5g|W8pAOB6{`BcC40!jzFX#(G zGh78{3KW_2>NIrf-PAB?e!__zvr-RGLx~}uCCRJrNrtd$pm9m0_bhPnCpaCDpd*L!I|0#3 zD7^iJGKX_t6W~z)N>-Ppn z7~Uj>T0CT9>oNK31)u|j>8Yif-V*DLEJN3>+my<%-^N59M;&_sADCabR&9_Fs3)AS5{V58M;SD z!p`rosKK_hw6r}w16gW#2GCwiZR&Tx_c16%H)!ImN~|Dy=gxFx1ena8!VCyfQYbX# zG+=i4STxK0pm_m!zsJW%2HJ}hjI5mv;`IFdI-z@QBZJIFwhv%5Vh;B&NvV-Rq9q`9x}HVo(v z@x6~aN>FVS-h#Z9Ob;*S6~%RJv5B>_i zlJ{Vpt#Md3?CHSefTsb|9PpiAV2p+%ANi^*vgxHV3JSl0Q`3xO2)J}^+xJLseF zrtcSD!fdtEI$4oU^?Mj@)4-<=s)`7RjRiHKrHK!Dtj0xPRX=l@8t&m7jG|5`{Yl~{ z0nYS0Z0j>-l(#@!O6ndP8v$&nKeVNhQr%p*Ga6VKK!3mibE8x&AO=4(WoNqO!(pRe zEN73O>q=w3gFOZ5bQU(;q&MRRRCM`=_pFcWn|9`m=CBZ8XK>&W%AQ+t|iyTF?^K*`rz#0v& zEl7Dzvwlo4ajB{afRpr)G zLXjO=a2!jplkqei9p-gBy|l*9;DV_3NA;;cLNv633)e)f3)pCB-3^mM*CfaqRSRDv zc#f`upl7$~EhdH$u^+Y3EGHiQBXa!0U?|jXX}p>Pu)y`yxzGU=1K(V=iE>Wge(B<G zAUvOJnwbPb1xhP%a^i;V083>+(4PkcN|{^ttAr^G1N_>VfCd9V*93iL&J^&~*1>_F zWcpK-bsVOg?gU_zFLqNdS)1g7bQ`@WiyKZDpakF`kbGbn^UuzHh^h#n0fD?Ispkb0 z0E#hMi`FfDTS7w?aAAfbC*Xum+vOpIA>h{@Ak9(o2>|b?MWAO4;7%V<2j8bHy^eYS zQ8BUY-{jVCsL|fu2V_3DYuovDBGe%caN(hJ?NB($Lj-NIQr+qDRtQdy2x<$4FAUxe z-TZ)$I>@pGBue|WH>QAs;so8fFOIvtU)XFc$$Uyk2!{Fa^V&QN6v04<-GWL(6vx4U z2j_$@2tgc^iF^RE&-ZBZZUdAH91HvcSGjCdFHWzock0!uS^*LW z2o4s5vI5OOsou>I;uFr@iuD9vQB(kh|LSxg_JMNdr}BrGno9qeqyk&N^jjX%u2gRRF5P2E z1vi~@a!_-Z3fn;!27yw%qsb(`n8#?TI?IT}3x}%_PQ`{k@G)fQ^{-gq;w70TyS+Ab zaFXweDU=N{69}^T?b*Gt*2`0raer^a?+{&I_Eh{8kwv1igb}Z-`18*n-;qCY_;EY0 zx8arLV(uLb=fI`F(p}XMfsaqZ3rmzxZyEcN54YDkIg;9lACoS=mr-5m5>j6#!ma!W zy^bR0T|lxb|Bs?Eeylx`X30kcf`L~4MG^Wo60hXy;zh)tGb~R~?7c#b7qn>)m0lT= z72CeDw}`sc6Uk)CV{H&_x&7!4ty{!viIm08-iL~rSX=B{PsN|t!}oc{T&`M5B~9_IuL?$;dp`2^Ee>Se@Y}Zvz8>!y@Xf1g zeAyBRX#~05a86Wx?E4^Am;K_kfNmo(_l4UvDdq3LitQC4j;Y6m*y?69$$7mwxc%V^djDF z#w5NQr|%^bW7rFT1P{m5HUFFA3N*7;4sCQq3N`D-YC zH)xl7c44sCW;1SiCo@fu--dRj!jmAkyZmQ*1$8Xr<6I%NPuoc2Wo_8VeEpCff{1(k z)$U#VxqicH?AL{bTn75a*nY~dpSX-!9MsLcO&Rh-r3ekw0KdM z?XWQUphM?c&ddyx_n%ztG6tQ*B*lD5H3?I7M_WlTHJp$eORtoYSNaS5k|{r@C7I8! z%gBqXW*X}@Wd?P%1%m2ONmrPRm->c0WzxC&YnuUH&{g`n5GwvO{VxsK7JB9H!%72Y*#Y6Sh z&<}6O%}q0lvms5@Zt~wklY`&rd)rwL%IrUipnG9r5 zjdGg@yc5(iI*fn9zu3QB?vW*gG}TUUXk1P)7`tQ2{k|%I_weDvoYkSq*9QRs1!=4g zv6SR~|EP9*teUHKUCnIpX=s*JPG8j@V@&*d@V>U?XDUE2pkZBfFx^_rdm_oNGJmy%$gNOp;LH zs^!V0m{BlJM_XEzp_~87mKR&iEPq<37B_FeG+@*!Db;L*fGeSdino2P=L7dde>T+%wzxqBZdp5>fo%$&#AV*@lI#nmxBd**V6Sp zsp|Qw)69t8Fj4K^62d-_w@vLLaQT*thE!3%nr&RXjUc9)u19nJks1OAbcBS&ZGCO* zg(@zG)9*`m;5Y-6_VF9gV;xC1$ChCkGy!yN@#e-?6 z5oWc&2j|u{`yFhpAJ6Nr*qdlqXqU&jQ~bKt`H4K`@Z+Y)(-XmOjg5r1!XC!mOne0l z6|3K2NszyGKbd}^O~(8XVmSVhc~d7(O-ku$Lt&{j^!jS&*=_oNr5it|DiI`)$Mjw0 z?FPwp;&s&S30!VJ(9zv`6`HQM8vg=$TifD%_k2eN5pf!)(@EUtt6`c9yZ+{o?VGkF z*XM6&NFJBjdkmH%n)4a9vQ^H#J(g;vo2i3kg_fU%)7k^jCLjKtJrgasn5R{y68dh_ z=#-9l}^!yh!k8-%W!IPZv?tCIn-(OG8TEO-qklL{Q) z8aghicG;r5rcY+w`ee;}zjEt2Sie^&OcpIPQzJ-7{5WN_Nejs&AsVOK7-}x&V7|*a zu1hYO{z&pwuUKMOeelnICJ@wn%YKyQ=x4ur5;}G9&DTR-DZaPFk9fsxXG)Nt^Ov7~ z@pMUi0{Z3J*bM~*1w{6Jy}0;C=J_VcS1o!XEtWmEG!s6%iu6f-?PPK$0MM3qQ-njCP zmrj<7L$;^(Cc!Ju1${7*-0z+D-iTb~bZN#LAUi*%9pvkFz01My-Pjc-%!;$xDx2f7 zmrJZhf-4o97;39s4~qQ<63UxYl8aOI{5NcJb43S3CB^MIf7 z-);A}PJ4+V%4fEAx$obc`J*#;(f?_aTxnAb?)#cW7OFens9nz`r}VaJKEWF|R8`*G zGp+Haq+zmwd)M!O$1%@QznW`3RSJ&hX>SHQ`oyj$@cjJa?UJ51_)Yq8N|Aq&<>%?7 zhVI#;`TGSU-fc_VdWT-P)+ZJIk%Q$i&xgb|6^ighr%%sgZO>#QCf(BYmNC($tAYdL zIMS|>|L z3Su@NNhxJGc>n$yK{#za9sVs*>F!suSrezHvt7^F<7IsAb3}Rwch8o+Tu83=XLauu z-1O{rcRL!a=vX5N3+`U>`Qb~XapIkDekFvrr6Zh-IxwXF$*b&NM%H@MA~v4y?@oU) zmo$>z6i*pC&G>%R&n#~s>IdU%+C{yXzEFQc-Jxq;%eQD zeBP~#P}LZ7`qgkm;E}wTPI%nhDWD*qGx1ax|LU3jj7Z&y#E*nQQ!a*(HrajC+l50- zCM8v-!hr#!TKoTmZTIS5FL}=13s{PHpA$lpBecfQ|2in^UhQ0F>H901cm+RG!t{IA zrsS%Et)h7ZZktJ6ww5#9Pj=+J8~CKoj9J!8bIf>sRjz)^I79aoL7H9{i7s1Dan@RD z=QQ%W!uMOUHzjzT|7b|*amR=M>a3q;&2xo2#Le`cYIu)y?E>26>X%fJ!9TGC{7V)o zu{g^5Z3f~>+_yu&Qdp<%O+Km&Be!GjElRrfe^K`qZc%q#-0#rc-Q7wk-Q6vM zq=X34-JL@tH7E@tB_-V{DWTFGB1nfYG-vZX?|Yr+{0YbFy1j0g@i%*RtiATyYkfcd zh6L(LURFDW4>IqLm(`fPCg*p2x3GyTA9xaPw=ue2h>d8Gyc|>z*L8 zFwyWJyM5)>iwxp>R$GU9e(ZY$hW%vj+^RsIbB&P2VL~KED@Y>ZDO#NNyA&8oBGnR- znzIBMXZnMI^&b)@*AB5}`*K(FU!AgFmOaS*lQMekyk+r`mwZp~m#$~Kv+B_y8LND> zA0##yR0wfN)2x;+tmeEfnbq}0gfrh$MjMnod~>dA4z6BAf{DGlJ8rPqh`;AJms&}Z zl#Sdz{J4g5NScv#(w8V6(7-~b6k84&wZqRT?mOB-p0zwy9q6a}#4KBs;6%qr83LZa zIlAgzYhfdaIXVUjIW#Ri{wYE9{^z%~*-puH)zB^PcsQze@VseAONEdiS9tSN&zKnm z$i-GV0%V+W630lRZz7CUiyBfaE1l4XQXy88^P}>L>IXhSDI}QzCZd7IiMqyL4{P25U%ZP#`C_?7l`$Ah7A4g+G zq=MzwZ5~dCJ5@ye2~l}WkRnB0uChww3x--BUP_mRb!;$?1U-g7%LQn37Pahijzs|T3i#^9m&3+_A`?p!bcX#Y`x;5e%Qd#JGjxf zm%zY$v(!lbHx)_iiB?l!V4#yB=*Jo2EsMRo?LIT#uZE6`u{NpvFRB9YNPn6 zoaUul#!~T+DNE>2R(O!>&5XO1gnKL^g*#_G;8+nh&|qCpN|N zR2X5^Vi0yk=SgM9Ct~F9tzP`7aMk{P3yE5ZZRSS0i#(U&0j0wFa#XC!D}zy+aOKcsXhN`L6^FN69jtHidb8 zYr<=7wqr1GB`ImfvM0fYfx#%cjfWzZMuaUeh^?-g0;fS0A#N-N<>=bx3iogs1qD!XJJjE_XdzXvf3V(Ujty5)>Y88UsSJ2Ll_!HgXmE z8+k-%$lre_YIC`cOZ#4pZg*|1eF&^UnTkPL_P`OO_;$`qx8#b?dk9yDONPbcl25uPh43>`xM=18(qL@cTBu zn2DP(-Ozh@N7!7bO~GD-y>{)HO&&{n$>Iv!9ObJ`S4`&j_mJPgMIBt9Bl=M#7t;74 z($k$|$;=o!fv81rjU@tJAxl`#}R2XsZ_bk%xE{e6pG#m&~B`&!6F@|^i2Z>5=&)D*6PakKxE|)XD>D6|U z@bShSzG-I#{>^d+HAWFn1a}MB&#FLt--o8&8?$eDQtOvw&iRYqn&c=b!Y#CdqLzYr zdWY#|`7hhhBmL2iTD}g_@7<@LV$#Ei{4=ydPo$3xl5Cf)@_*NFPG-6DISjK+GJj_D zS~DzN>iBVVIk(R@$vT2zbn`k8y^k~)86DF$)vxgL<*T-f9QI7G5-{{Ke~Y#bWB47} z|Kg^=MF?pb2pwF#7gR;ti@f>i64zulwETJhka)8+SJm zm{+9;qrY&&;gk2|3rgc@cN1{`yRf#9c9q2ruHcY94%sq%-mOB06*#wK5niDOr84A? zpmwhXc#8P_(QC+%VU5iUhq1Onbrqk_pYb#N%A^ox6(84w#$=4c@^sA7r7clRwepjB z)~l$^{m^BkdoyR@YC(Va=4-ryDR!G@aP@enF3mhvr8mFkXXK|^442PMrga)Bh;(_{ z4b4V4ROaTpfBqO*?y}tO-GdA@yxN?&tbliPdP||8BWAy4@_Kry)NP%ElGhQ{JMkI~ z!kpNGUhb!+B#{_=_4%;}ov@J_nKYj(b1}lzNDG23XmLx%jK4^z;-;>4Iyrw-C z@2E=k{#>;HZ+q=Rnp_tKmk1>L?hAZ|K&RZr@&gI5mZGGtB88fb|MlI>wcTDMqaPxM z2x@k&^3p9v_f`Dv3Ky)UvP4Tf7ZYT)C-yjuug-pXW-jRkf!t;s5E4W_A7+tt8wq_8 zm~;>M{jsGsd;$UTyQ6D~M;HroF=FYEb8Y~|r9zG*1a^NX@W^9v)_Qy~I&ebc4C@fF zyC=P_UBK_fgDA!B4ssqmsnVBLq*`j@sef9~nK;-sgx7p!|E?5(RhoRh)&(iXtfGA0 zDCOp>jO>j--^qT`e(cyyMM7+|y_( zhpdm}ko?*hH-XrhD%HNTDjAa3b~*%t`e}ORKx8(j*ss?#wxHlbQ$@De zklUHh0{TFh(_D)Fq{VPN*0Wh&ywi_K@PNzV_-J*-&nq3NPT){uMRMMQTmY9W4HF@i z?Rth-kBcOeS01-D_}rhHS*7=c@M^WdKAU#Eo8X=FOC!-IF)BTWm;{l$Ht8l`KOYXf ziLo(7h1BGx4T?D0<%K-NaFbELJZ(-1IGkR&(D@qDec@ zuph9sVP%AaGvce}s|rduZmS>hAlj50`LV}clH2V9cp?>lTG>hvpE()t8zeG*wF!34 z_!U#V>QvLUFeS-;z1{_Zc&Zn|1Atjn!{-gS5`rr4;wlrKuJ}JQH|;mq-Px+0N^2Xm zv0T{utaO)n5SVHb+Z*7)^JEaOFW`+3M4E;|5>GV-`cckfsON>2@L%bx zTn(HgERGI~r5h>;{o8KxqwO3ge;$?IPMbrg`Xe;6_6Z46td(?jPm+ggiu-KuhNe$1 zHs)bW{^Q0$gOMX0a=piA$@de!A@tq)eu&0LdE~kus@}glm4o=Vv?E?lW}Wu^gnl}N zj$Pad#UBb$FTC;HQXgMDy-O@oT@lxxD5vqcm%iCM(S?Xmt-W4 zXS~}w5U%!E-^i5Vuerp&I*T*eV?X*t0KvDhZC!c6Ye<-y-}q{6auyX={6o8)`&9}DRl&`InKS=V+oa4L9)r==(Z7c&N z{_H6hguZlYj8^Go>%!Ia-$uBLQTaM=hQNw8q0S2=|4pKNo-wIU^^x?>l9&%|v;Ovee{Qjmj9cUYCOMv|9V(>TPZw8o z6i>~&Cs`7`wD*%zr(fE@}|v-PsUUybv;4t>wc(h=)bQJ!f2 zXlug%@z*k>i3pmM67Kko#O>nUZ1Y0sSnAaRH5=EKilXipYT{Jp;kK4n?!pSIp>iqA zQ*(D}a#Jw>Lc&cncnEjlN%fvPo%6SFgM{AW_$uYt>Gk#kiIYBDUu~%1kf8xEAtUkk zTl9zo3K*&~A!;e@TNvfv1vBla$LxX&->+6S3}6&}3&zl{DP|d_khAtgoS<_Qj{ryZ z&B|{#e`eeAbLg)^Zp-i;f8J^M#UX;KTz!7?R*J2 zfx7W@H(R%$QYz6Dn zpWiyRL~Q~_vV-pgEeDxkR+pzd1&J#1KShF-?gx4N)mVUK35RzSGBJ=Xp@cxWPbI?7 zeH>C-OkJ)CGMk!79t}%7eUB9|uJM~u(B>ey96dH$fjH~5P%O8}7vz7U zbT@eTxzYLYi&wwA;#!-d@1^Z#4`~xTc}dNm(8|;2-R_moP6C}A=VaDLdYjr@_G%oL znXwHTsb;qO4kx@S2{t0mVnkdRR6vU6nAX2g&xs8kv5<+$+yhpD&XTB%^p8bBQpT)J z%-`U(ZvPrL!PR6xYwx5n?r}zIIH;iJHISjn2`Qwzj|*$4Zrph;LBR^CIr`9M!6kMs zZLOF0d2f_MfIwSugv&4@5%hsg@anKFjr~3m^b|$O*qYMOaHm&(RB5W_XphaW+jMTIL`8el^teVC%xKHeVgmA zQWpKo1jAFE5e!O%L&z)(f?XtH?IDU7_9>x`m0z3M6!e19dD+ljht{Jk(40dUY8vNc zdt#I8gm-Dm4v3pl=;ywZ;PSc2ZXSYw5}y+J64t#;+Fy|YgsSS%sp+!QOk*p zS0!M$eKqk5q0X~9Ewst{mGViSx~+-C3(qOVUAEXid0@1^pEf*PXiOO1XLy~NG>xJD z0M4QPN@(0jYx8qYXYkhp54v4N_R9S&iPOR+cdnw&Rr+(ta88EFC3RSBIFoi3|KW6h zstwe|jIjB?xd43c99Bb)(kUE$u7iW6Q<&xVn2z^Nw;4oMg3GjPJYwF%G)dAfC5P{A zHyNCtt^{jRhjhg!!_tCo1__!!ErcqkFbglSP+xmtnM_J$s@9Nluw*8V2$(gL1ty=uvJh#5Hx?g! zf)n?jqCMe)J0L>X`Hgnktoi*UJ@u%sS8u|aZl&$(;VH}x!I#fOHdV`dL*|_sExs!x z1ey8A?!U)?)XZl27$?$2U6qQj165}3pUxR<#WT0zE3*>fhV$3CcZ;oUCF(ACcXp8> z!j*(Z>4*udrUW`RB&hmB0UYfqG z2OVjgz^onzci_H!Q{Uv}GA(&kP^fDoJyCIp<hB(E1fSbgSHjq zTxLR?e2&X?cz1qp#*>Q0a%~~(6;vn+EhH(VPV_)K!3AmSbEMATo@29})uMGa`{%3) zAu?REl-F)`?QvIM7YEjN-JjPUf39OdMxXJDd9S8CHFjy*#e~PL6Q*C;ZwkY9JPxRh zLjT@wNj!ja7+U6;c;islF8;~e4O`LXGyhzW-rHwYB0@y7*;#*=Gb=kCqhpURcN1J% z}!8Za(Db1nz-Ovd>khP84`dHDVj~LZTBZ-LURo?l_t1Kn1+dLQM4~^gdibKWUIKI zdVN;Xd5feBC*R|KJgkr*9z|5@p)3u3lgnB_>u~k~<;S9}Xvy;ZZjS+KUk1%9tq;2XV9&00)Y&DZ28)AhS1$TxT>^NgF&uc^15WPliT>OtAqbB zPmMt~GBU#5p-I#fmb25p=ZkSR7?n0Ej;7HD!z>gY!yWCb6@BNd#SrbyDv?sP!)g6x zaYx;pMV@q7Y4wP;Wt$M_5ED4{5BVK?!r$6@fpgyXn-vf^fN--BQ$D?3nmK;^;p3eA z8_i%^6&o$t0UE7(qnp;RbUq}Ua%(MdvX$SGQ67*WP(dzE?^|S_v%c#7xu z+|z%ReRc5_Chjm?$a78=vmRkcNuU&jz<;g$k3KZ8$C%nL%J< zh%yEJHP){g%a3^T8dWurub1!mQ^s?VA?%m8uU$TJ;w!8FQN|AVN7MD4S?)-Nymbx9 z_XR18=qC3ACZtsKINH19=-(pQ&AxKwQATKI(^CW>{F*h<@9}{33qd7}m4QMI7O)}I zGR5jOILVSHoi&G(%uMa$8Z~kAxKd!#5sr1-|D04aybh3dJo~aylbm^S7veL*e*XK4dFzjV;rpcGz80oX4~m zYS2NYbr^hame*RxU4bzu`Ce#QUHr`QU&kJY^f8Xsqj{Ih_!60~D_3v${%j#z-~Ki1 zCw|d zmv7VIXCPdHkZiC|9k7KUO;`}4(R;kLO3%pG(6hN1&R|+)eN>Okodv72bLx<~kzz-& zvtJ)L-h&Xo~N zEv3eC>?qTHjC_Yl@0H8uE4`17?9Yt|isPhZ>u0iJT0>1N96FdeNtI{FhI?VH8}!-7lIF*n0daAxtZ^9CWsizjWBV0 zEm;~0EDad(Q564lbR)w<5e)I7C#{X>+9{AAoe#O_^*l1H46(8N83^u57kI%q-9%iE zU)Ob{iVtt~-*11m0#{4)>}f+E&qYlH-U6Swv}mKFE2+;}e8Zg^xaq=;e?De*c5J0123Y<`npAuXtT7l&*hKtJ%S-*=%#Cn{nk%l>TsX!#o=Md;MSLq zaXna)$dKQiG(Nh!7>E^e^r@Wq3A!PQESI;diujNYRq2EmnjaYp#~IG5NV;=Isc4iG z{BbU%grCv< z6EvUWq6d>B%G$xXAbq(0pmktE&wog>%XzaYwjX5q4E1{umTC?{U+yJ#8%+ITG}aON zOm=l`&UUHm+tCs4X=ENE?cUdq$2U{Y3$#*4iA2g(k!WqaW)~CGN5dL;(LOVYfuDsk zAo&G2JPafgqL^_k;+XD8;bzdp7N~voz zSf(2lPg|O|i;oady(KPaxCWH*kEq;&W7QJeun6KIc2~;{dI|W3iP>KQ?Mj z0=r_wWkPJ_J<0=|Lt5@H{SrKZ(Sd^_cMISos~)`Nf7{MQm%_}TTcK#A%kRgmC{v*? z3*7d(4@qsg52X58M=uRMY+7d?{4>9JeSdW&ajI6%0C}_n+?0XW=dbrxtOUGsr`YlMn zXAjkz<~^#jVD+{QFg`F}!=O1J1$bd*9D~hvYK1 z2+dGJq0GKbu0Jf~4tEvi?)s&uZf~ZD$;|h{$%fie)ELvEgs*N7*Q&sDFs8f30vEb7 zM@wz6mHh*lE)|BSN0>iFB~tWi!1rn#2=#*-i5K@R8i>5B26U*pyS7K?IZerl)aoDJ=#vf7Z8 zh}1nP>G8d!%N(9ka6+%GUuo74ywuwF@gM{S44o4Pbu4hgShbWv*L^@S^6OtK@ugL! z4zaieeZTe|If(`z_Z8sS^W|P#fBFpPv~{+NgoFkZ>GyKbw8kKtV6*l62Y=axGR0J-zVOJ<3dw$O2EXj~1 z?d?@D-HL&s@5H4fh+e-*i|R)fe*10~HGotD98;@v_R6o8{+hKDZygf&cKFGc#0XA) z7*R`_h4u|{QDPgRrVy75u9?MvsL!1H7hO-6Ca;}NPXz-u+LY|!7YOZxtce9j>A+??HLjNI@!iO}ClY|rQbiv^~ z$t;5)SB(HYlp-82*VQ{*#`d@Fo2As6vHR=%dSlYQuA^EB_(2|jlrw>YXSaHk{P8_1 zp_C@E4$xnri#n1Zb->w->4zLwnALHt6?>o#0 zdvPALCAMMy6`Ba|_=0Xs_oXtrLSsMil3=^w^9FOdmy^Rx^P3$3h-)94v?DZJmitw zDk;ez?ApS{nI$OEUm^;%r1}dJhgxLqMb~T#+6j z=^2MO3p1;4Dj7685L$m)mM#nUe&_&|NI9Vaz?0i2KymAD3QZ<(M6#e_E1fwyZkD)Oe~09+&5#df}ph~r(_U!{GD2B z3=ckMQF#1jlZLU0jGTeV-tJ_DGpvwFd*U(7!&c_}1@5pjT}Di@i`Z0`4;AK5M;x!s z3txQFMc-M7U#ntAAdQA(#UXLzh@$#%uzwam>Uu7uV?UWr>W|bKx_eMQ#+vpPNR z@o(fpBM1sKaq|%9un+k=e^^aik*`!K(Wf`7aCp{?1tAgmF3MPVtP9cwZY2-)&Ar1# zd5a;;yDy=(jNoO`X!N*W)3_aliSeMsWLH~!zMdh?*eOd7~hh8v>5yZhRp>=zfi)B5oDg1hohz>5D2zhE(AnD?qXo zO>1SL?1MguPSW2i|h|5nWBsds?whZyqU-+kn8UC zrVrC7pNyGnb6I?C6Bc1%EYwtBV8^dG!)sDGrc*LyT(hP3HK^9@W>s}@OUD*KdQ3Pg z@Ih>aAkHi}#B{4Y84_ga5>g1+nwzo-ftLTCtFf}Gp4_f(*jTfG^XamGeCH>05TkK6HO(9Qwd>IPXB}hLZ}-ZGFcmSw6~IJy z(TOoJ>=Ya?s*qa*^OsqAw0SDXApEK#iw5;EF3r#5)4~wnQ%YCpNmme30H!GT#ucYH z8=?vr>4Iec><|c4izob5BOx6|A!EfM%Ph{X%JCtUst7qo2>;3uAy_iC6J;L;kyl_hhiYG{%d)tn3V@G`*u@SiWbUygV^G0y&IIeFj5vAp;<_q> zK*6bAacr||JQm8X4lWG?M))Ao>Sk3br4Osdv5CFQiQw-ia*1mQ85wf#F0Ql1QSKG3Pt2GykNhp`W8ieYZhA2s~y{GtoGnx1S*8) z-UzsG=Beido-rZQ9KAon;!W!0kcmJ^S81JVEAvz{9SA}C2F5WXc$GwaB>i+^l*71R z*iVQv_2fV1wxFsm=Hq}v*ofsRs)K}-e^?s`dOQC`6=8n@vn;)lq_?^8OLUi~3>cSr z{6G-}3}7c@lD$taYD5*hCma8keSU#l?pRVonhu4se^-H&JDl2)Tj$K?q(pFLO?^34 z;Ktxx+!B-@>(Q8jmdeGBkDlLCr^2ur<|9C86Y_hla>4Qh)jh{A9PXoL`7Hk3-^U=& zf@SDJ&edo>Hy9jJO`WiL2H7IyG*2%?Zx!hV(+DF5dR->|=87hE1QG$wIf~WCg)*{> zZ_ya1U16uBrh;gf%~*bbxKXc~7hUQJy6I~iXm5s1{@ef9B5^Gd6fJ$Hnb$ddo_ zkE~tG5wxJh|Z8^$KQ9pKY$@*~b`I}QeyFEl4G~dTr z$fT1}mqZ*Ea&25OEi5n0k2r<78dFA1?A7x|E>7@0sPEc^f`$xG5}a4;h{}WhhU8F6 zR#;>gBG;M9H*{DF9K}FD30*1C>C`sL@9rvKalj^a0Zjim=q?=_Hs`N@qi1~g^I@DEDzQ?%5F`P@26eIaNB4QMdr=4`@Um<*%n%UE;o{7$Zk5o|9q zXlW|AktrHo$$2s?O@qnyS$>M5{~r~_KDRa+>2lsG5998fPtKl?EH_wBQ$I|7y}Hc3 zLq&yGY6x_c2BQt9GF~A@Kt)B(wncLvzh^NE@1=}94JfFbn(;G{6qqxmXS931{4j1$ z-B3vjTONxpmY=dvfYZf*A87WSw7IZaAUyPFQeXWC(Aucvxg-B`RsYVzw`Z8qXn;#@?A?=V0Tj;CdIHldF2W<;7S!XR)m3_u$#yZ4*I zmT~N?EC|-l=U@g=9aq+;6N9SZOH_=7-CzB@WW4Z1@<+S5BMCR}pO%VVM&F)wuKwrM zDBoSAXEkaCtiS<=2qzNi*v5@hCS(RNi6nY4T9Ev|2y8nwz;5aWBskvS`&)DG$WlXD zuScu?pFg|VfI6Xc@LP2p`{AsYsDd#X|Z7mt1^t3Hb78o+cQ4Q4#*1h93m3@ z;5d7uMTioejPgBMy!6`72KXzP z@hK!M^sx=z?Or;((|MV0y5NnU4~0k8C>w-U26dXjt{*VfmFdb`d@*oDPY_vmJ^Msc zy{u!JPDw5!-xTw6ROf#l+703HMQJs6)s;So`R_yD>hKEi^MT0$?N!KkP3sY` z<9Tf80C=N-2j8Z->OtgxdH~0?LYS}k;^(CB2nH8HB+}a>U0#9nyUjc5NZ=J0JB?7M)ZT&G+9hFu@@SKN2`nkd$Ii#5HvYG%pR07TGj)ya z?Wni<)1k5yb>1AAA8-Ei)MuzZWrk$z;Gs`X9=j@WmL2o+A5C=;4wgdg>FobBl?yaA zF$BTP?&4c~I7b6X5Lyg-wg7ygKZc_r{ZC?1b~m98yX6Hp$_ussbgK4a_SkGrMoHv7Xu@N+ zHUcpMkaThs4AdR$PsvwCb=Y-L+K`CFOWbt8xEc|4h07Bkl0PXj7oF}N0t)>iVEql1 zQfc1+-Ke`Nd!P<62y{6E%oc&B^*|EcD)$`|BPLoHvx;twm`s5jJ~E&b(n-+*ofJG> z=hEjJ-pSD3dGMUOuuky1ZRWIvF~Etr4P$CwW-D}ggk1r0XwP;7z~Y<~5I0bP7BySO zp#Gm}4oaPQ{}<@4lli|ucXsvv<;N3${$GAP<_oZM{I~WW{2j^hzxC3+{~vz5#sBi- z|9|>%FKfi43#@?(mt-)lkdsNWA1Hz*&fQHa{;w*8qe?Qf3Z@s^$I<3stnv%z(35Ub zc-{lLTM|@L102esrl#_y7+kd?z9cTg}O3nlU-*K7*ZOH?N6bu)Wt3&qzGof>dn>d&w%q$>6k~?t<%3u zegR;13vow43r=S{dt=x=ZK>5_N}p~8pJHOpRAGAwK^L$GbqbG3PZ3rmMxE9p(nN7^ z%jKsQgc*v*BR1Iy8U!Voi^rp^ICOk4|3Z#F8>6DMVaXpPB5@KlkfM5*)rp=}nm!sk zPYAVAS=~l<8v8AK*m7vp@8DC_nl@=-VibD=zde5ff*ejw)w1V{b^95>azaHRiA7U} zu9~XEL9)-nVX3->lVg&E3Isi^4(!C3=wbM2?nIGPRbd6=PTl)28CopjnN{!=_}I5e zkr6NzhW>ctM-q>#l-0$7U<%UPE|PtxOVAy}uar@Jz|zMYLzrH7Sp%Yh!ayChEzh_> zJ0Bo+m9?~@HB0Q_XGD=inyN8x3NCBv+uNyuGC0;q3lq>50@6B}nVI?7dF&=G)GR5H zmSw1Q^YK7Xr{~SHhMSgYI{_df5)DNAo=vJTD3R>rP7VIVmj9}~G}L3uzX9}P2Y{f- z;@uKvBqGpH{02ZjynrqTL^d1z>7czRY{l}=z#uWCf8&Z}1o*o|EbeA{_H8X?xu@CkHw<{BNT0I@IDd&L(Gu(4z+(A0^`%YhX*-3To18W~X)f4Dt+6f6SDq=T+N z6ovBJZ-ii=n~F>8`LsRlDu0JEDL9 zT0Y4<8zUp8(#w}GACV|PmL~*2tB^EHp3j*wVp2SzP?jG%0czP+Ko{*%o7%i-F?^YZ zF~gR=@K+IF9|C6DKOlHHE&1z(@B>iB#-U9GJqa`O!vIgrd}fUaG~e=dPf`^7;}r^x z)&YSJ9Wf;}6`;hM_3r@J3;w7Vhx2ylqdIHA_1}`x^77ut!301n+lB@GgU12ta$)!W z^uyNWZ!ru|?VA>C074CnlPP$w-Ek(+<+1?sxj=xH0cg%ffbr57lhRS+qm-2|0lI8! zt$S5{G~P=;?D7n?wL^fwj{xNwO32Tjx`0K83Gf%qdsS_Z3=~`bkNSIPz&kd6!A?vr z>_Ttmf6xG$2qfPqfsU6`J#zT8-8V2REoxu!y8ltn4!*xy>IdE4o2XgR2UHH{$`NLO zATu?fd)n~>jt7800@?!U$mpo+vgd38z^GKtpcQ639vC&VG)urWrbk6!Zjbq~|1~8E z2?_v^#l~Xh6@LQ#<^zMgP*TkX5>i?()Xa~=wWMj zbpXB!i1M1b;(zPXrry?B0B{&@))+ev&mbU$Bt}F;=pSG`ek(#`1n@{+a_q{B5288a z&PWXah(Mm82{lDklG; z&n0>Fn+Pxzkpbc(eBJ?9mj@hKWK)eQ2f)tfa?mh)UFqd2ObrN(Vvo22Al6?5@P5i2 z%wU$80v*e5fF1dxSd9#vx**zSW{Ea=firC}96J71s5}IbC^?$ZK!lPGOdt-R0`$&d zp?;Vxb#}|m0;osveo*J7OlaQ!I@K48Y`5&l1D|f}D@|4;?eJ!_;AW!K{-RV`pGiEPSlzE*^t5HHEW#e6AbvkZvHnb^rV{_h3f6ealMgvxWc~4ivp+d4`Mu(&*8SJQ^j* zKv4i|ST~4H2C=$2RbUbUqz9*1H-;_O2N5g?$)nk+sWl+wCueQV0<^+izL6NXJa@5k ze>|+togn+7$aDjB!y4bSw_WtkK%xy0abkcu0;m8q(6b=lf`9-4sBS%yp26=B=*^Rx zzI?$3yMltJr@&#`kt6^pW~9CVPEgb3rO%S@c8)eX@gv0k32EHKjbIiI03AmSzmWy+ zuIjq?lUGzkEo+*h;MK~8_NJIs4+0`-65w!}PnGH*JdXng_7PW31+-sfoSdq(0>MNp z99J*aEdaWe%}1R#E}IEL;In#TL-+?=FnHhPfUO!=ylz{n7FRM8)LZ zdC-vY=X&qYfm#!XD90fc+Nr!xmdsD@GLWWdyD1peai6_$2EfkGb00;Fe|18m7uEcD5HsEJ5Pjs`Kcb@m_X!Er zw^E-A=$uEl9kybDL)+t*>Wveu6;u3ji93(iIrFbpMItbf=DOeL--oZn9nP7{Zx%BW z+YT5`_DA_Fzj-T@yX(}n$xd|6lU4Q9lypYr>oMG{H@yc%H`so+9dC0=iU*sbd(Q6hkm&q{fZJs! z#K*_ieDx=*_RVk#csOh{5{PQH19di^(_ywpmC45eY@LLvHhvwE=#ouMj!)}aUV5`8 za1i@4?t!$~LBG2rV$p(6z?ar0J=^TaZ2>_={4TL6taJsbKt(x9w*VOgAtqJ8MArv>|c*Z{!8}^bC7?J9!NigtREzer)o%q53b1Ne6*_)F9)* zf{P0g((E7r0oRqk)FzEY4;#?}?dZTO1!{ju_N(w4sI-#3!4}(mf4v8K+?1OPNEJr_ z5x5N?RjL3aE%E(Si6u=0u^3ZYi*7#oQ%Z3EUE+GYhYe$vaT8hRH$a4B|0(}(=b#0` zJE8j@DE6K38-OmP@FU$wOZ-H>o4%S^1uz>2z~YVr6UhqD!(?_?Jm1i`r3x6UFo41q zxz(pg9Arx&M635~Yp$+xn{v7&3S)BS^yAh)iopnMKxKV4(>|6~q0M^s+=R9xgi z5s4HY-HkPWy9UbbY~Gm3?k+&2dI{j|wb?~~ zyK@)*(!brL{B&#m%AJfmnfxtLLJnR!Jlk$ehtwkHZeTvS1fTTm6&Si%`+%CX2P%{yAz@l(sE~bANtE{zH+=yP!FA)mes0F>O!*G{ zh#L3_lp;;Mmp$1f(O)jSwS7xadAL5y0%0d_d(f6mfYDl)O+HE9UKm7=@wInJy zSit(cWQtECFt*#4uSd^mh5UW+qQ~0tCZ4r;Jfu`ITqmclPGQ^roz%-HVUYHle|@a4 zT-7-0GXvj2lPL?k+F2oOh}#vzZcN#%F(4IcO90iU9vl)D;&Y9`Y={8%4h5uSOMpm8 zptWgQ)X~;nBU`F`B#UTCUNK{0Vb!{<%a%^>&Kz><%yC`rSbs8cVJdvCBwO|zxqTeo zd}=1G-Rd}ebiM{FjoBbo&rT~M8GluO3o7C z)~9^X;|0X8*CGTcJc3!<5N zTarjl7!HE%AGu zuWg4v?M(i?rERTV!1;yRMD`N6ZvO_SpK=50%_xA{jhSA)uKkf&1-L3T^A0$)R?j{h zgYC5d9DaZ_{(u6RDX|L)r;eJ(9K`wR;v8J(<;Q5q>a7xhqVMVLK(JMS(Uo=+)xI3F zJKAH)TZKgGMy{6R;6R{bwBeu#DKtL?Ng5Oy;&$FbIK7=CbZxqKU_ym<4|H=kQ;w2( ztJ@h@%GCZJnrI-xPp5|r1<1fH>=m`xZgB#bUUMQxl-L26tCL;_7|qTp2kGDD>H<{j zo)s8zx_(#KxtnOwwqpvs_PI@k0cv@2&@U_D(?eb2z=!*H=g&7UzCr0kX-9_0uTWIN zD)*5uxVw-7k4zqOG3GofO#EwCr)-jW-547s+{#|_(m=k_1>o?_+gR@15e2d~fpj6m z=dTs@2DF_GaaBBeQiAdQGnQj5)TB_5S?)9W#<$PFj4J@u$B%O%Otv#W7MDu%z8UqK zRA2-H)=a*`sE}~8TZUF!(2G2o+nSSxht3paSPNmaSuin)Sg_MY7lOcR4Go;r)7vhh z5|C5#I7#S&E*TD6_qab5yB5h8US|tpal##BK@ESh*E#m@Dzgo5^`FfEL#ed-`cL4~ zy-uJyaY^ma@qjqTr6p?(?U`h;FG$pc(7(>Yk_;q4v44%SCUxk?oEs?Ow|ZPNLCzv_ zAY6MW@%wJ|m*k*V+m8NY7TNE13A&HXRXZ0=#`)06mPKLh(>CN`qTb;>$z$dUB>8SK3q&ahyH*7#?>6D9-v?f-yWJ8`GGoNJm=>n zQfPA34xnBDa@tkPv8}al>V>A*dHXzjI(<%XuUc|jwZr9Et5z%82y}rZo;wSLt@94= zRLZZ+_W?4+SjhfJdk1uqHx&W7KWrZ*uw#v#J4k@57w}4zSh|O-nkn#fQ#p%Y!aSsD z*7KBl^Y`@v8jmTPN!-6IT7{M9iNfQGCJ1EtL6Ss*yJJ)g{*{Ex9R4%QqFS8d*he~- z&RXh{64bZMWIRn{o=PLNT#e-@<&gX%Aqb2}A1pZGl!nU1ZP|av8~#Gu@51A~JjmLx zn%$tHZ^^E_zCj2vdze`9jaGVDjC1RKl=Jb!zZKV%d<@@_#BeP_w; zT3HxN@b)9$QO`Okm3u?9nRVq8>sdAL_u?n|y>3c%7m4mQvwB{yR`SZfi0)lItT}x9?WI9mcvbsKIjITeicrxzuIlZz_Ni97yc0`((djyt&ChI53Qv@^ zXX7)(PI6ZKMA;(AX~T84!urT55rxjG{3`02H!FCdATuKHF>X6k+=uRpdY z&Os>Y`$fdUZ%;lC@4!fj?`Vu&D5S|}M=rTwyARugewXSx5wpDqauzbbCGWle@u`%0 zagi7Y_nW__l#qPWFL z?G^qX!rm&Z$}jr&rKLkax;q5v?h;VCQ$o7CJES|54(U$m5=rT9q`SNJ~nFR z=iG3$SnGXfj5)@4e8!jvTSD)9(bGgfW{Lcc%nbVZ4>$DrOsmkFi-m$Z_H5CeRC?~9 zq2%VA76tDol6_#Rs@oQS|G z)iWoAff8ZsYkvxwvOMWSPYGZV{7Z4KY=r8{Q&=R2nn{_3)kaK-Fu&%p3VOZfBQ$37 z9;R!&y#Eu#9uV&_c~X%w|F}D?_tdpj&376X6YC6@p6pR_V=Y;)`Qax@Cwj)m=SkW5 zr$63_<#+hmij_(5M{M7|wX$#@smZ7_iK($ycs&-+BfITZ>hWHUQrULx;2L=}IrZi} z6!TRPh<$onkjUt;-Fgzf6i1ma>0M26-cC^?GmaO_!Q1hWTnE>vSRWRhjs)SM%aZUr ztg#t+5u0l2cr}`g9>YkK*V}}DC)_EXlD*laiJ0%&VxmwGIbPh`9&-Aj(cs&+p7czy zIlo$-gQk}%^W)0J$oRJ~#A$ye#l*x)#|xoEJbCg^T0@iDrMtg!8OaTa5D*@t9;rBa zbUTbfs($rFV$Mz&4-q$v+>^MMF4J*lk31fx!5}ktK4)Bp{a5d8 zTp)A)O{8fU--%Xq<^AAfsm3_z5MtW)r`+RHJy*w(^xo7-N9nsO5?Z&LuxKA_;-H<( z4W)l7Ng2!|`CgOZ+t#>$H~DYXBL}>_&;*0MO_nd)MhQ?s-Xta8r3ZG$~uNCdq2zc&uM;>F3cp5kvzE+N)32U^W6S7Ih_xKkhj z?;&B=D87i9r!$!+!A>I$4Qk!pXqU?oV<5rA!E4!9E@HM`m=xWK3Z?e}e9Did0EZ=y zMzQ7}+fD6eSRq9SlYOs|Z zgeLobemGt4g`v8l-umH5ckh+=?l=QGF0nr{D& zj^FFdD{!PHn_m;HZ)Z+Y4~1Go%Gh3QH(zGQ-<}Ij@RlEc-NLcjW=i$`IpM#=>%|Js zYY%yOde%uto!yGc???7_#@@cv$%?BcqB6SA0~=tyYO30P<}2~|cagB~$f-}6eD{S^ zh2$rE+(V0m9MmmzgYR}~*veLJs0AtsZMJLI2ckvR;$-ZC|$;q95SzId;KlgMm zU_&ABiwQKkdBG&}h8sT_N#R{4Wk!c@<}!71CM)cbx%+b>9RIOto7mNJqS7RQQn}4F zPyIPJp;|R1aG#R4m@9l7Sr{zuDJ+|_d3`_%46 zaQ(dQO01`v9BqTCGl$A{t!^qW#)3&>8bfm$gV%ux{x~DX5b@{pt`BoC-dg|ciT$Qf zMBBn7KK34TR@*;=C0O)1D0iA$34E~Io%^yQ0;aYk(_Csfc_Z%Y3hS| zzRdc+LTw^_?agL%6cVTQmh#QQpFVkf6d}4n;1$8Sw3yY_%TO*l-Z!e)WQ1L#CO>G~ z8kN?f-_mEj1gk(Khx-1Z*7Eq#vF+X!UYdG4T{@e0p8Y{x&wc9v^~9M^<1V-DygjdG zUuAE%iS!vN1+w9%J&w+~>jpHY!cnH;+dVG+{hF_bnkV3yQ=W_#AI-VTQGf*-mx>@y zg%Xkf7kRw~JNGgUMGluCHB}`$_b{WQka4_N`$QtRc>iTux_A2ofxwL^Q2&uxf<$O1 z&b5My>b}mYc%NDKHRq*JJoHlfyy07#wPsh2>+=zFG35rNOc1$OSjSqhnp8dSvF^wgL;H_c1VQV?z2jd^cyR`my}xCKsvg6|2l831ux_p8myA zIlLlg=0`8+P|TqBz8knpuLm{+9iB$@8}-S$yWs9c6sKfYIjJkJ6QP5J)Y*Y0#p09CAyaOQ;C@AaYLFB?FKiNH;wxe1aXe2 z5TOK7HSM%+J1pl}cW)gF3iZr+e13yhdq0QlN6n6CC&9E&u_18JjlWO(RDC78N*%y< zArfE2S2t^GQ1zqxJ$oYYAM}jftFr&Q96Qm|BJ*-)hJGyuEh|YK`DcJ-mo_Ui;1U%j zZ88?X>>X`6I7`C$MWr-LcZ}EQPvm+Jnoxd4g*WWULJv^n-N6h>cu;hiuSas7YRNcy?8W|K(TuLI zUE}EW^)XwWB3=1XQ^B+|<5#H{I?rhgCMf zM=`fHt0(vDok=2=sikYGJYbfi38(vg!VMladSPGfW)icgyu8<>w!T2g1o5y^MN`UhiPXbt0kR6Ld=5ojXLGfnr52Te-?w#eC!ZNo}*3r*}iCOR1DxM|3}Zs)-O{NY(7anbtt|JntzEQe9W0y zqiWLqp>3rr-Y@p3J0rFNq5TX&5$A@*PD?q*f!H5VFP<#2qy(SUCaf2sjB~yB^D@xAq(=glI z%E0mGOV(3btG6{a(^ki`83x%?2zoq6HianFZCcA|boE(% zqt2l>5DM*EZ(Q1MAJqB~U(>lp;0|+Zk2@Gx*s~G76T7SF=L=Otn1D`)-CFgq4m-r0 zr;uM4?W$JKT@0#GLNdM@-C@qFV<3D9+KMCx)>8ASB}eb647*QQwXHQk%TUg~ElZnT zNiM<`Eh0IXp+Uf#!T|}|Of1;;=M2tF32P)v+gRVvTMprn?ertp*1RtH-=8CprohTm zDI9{HiJ2(G)g5$C*in`mh2?-v`|#7)KxXeCDSogmS?=VT9I8q!IiSR_lxfG;)kP`w z1aBBz9zGh0$F{Jwb#dB*b`dJ{=nqgx&Ij-k6@>+MNHaBgyW~JoIVW9+g(u{|)-$(% zG-w3Q+i{XOa?7`s$t+UftZK*+Zk*O|wcT?j5_YF&%9_2bPaUpqI|_oM?Bm;HnNe~b z)U=bEs1gs!l|#2iNT;!h0**|5dw@h<>O?GzVgk~3Fc zC>bHb=>Ema2y6^Hk9z4<)DB}}gagumoNSv^PU^LSs~a=HaR)9iFkO4lIub2f3iZod zB9?N0U3Szhq%}rp71%Z8-3sg!7BxA>mDE1WpcswV(pL5VC7;aZ+UtHx4%fP@k8lc6QbX)LCXEI>}CH@ ztBoc@!_)5IU$3iWpFS!bhrGVgBtD&$ey{>b>DpfoQk4Lof9hU{)rJB5yW;?G2=q22 zkxGMz7v3)bt<7<`J_PvJM5CkG4i7TMZE>;fVqQ0;<>E60x`TyiSb2D2I`l;)u7WKt zAyMi_S_f28@;^VA=;j&zm=cvs|1@00@F=LfSPbJTt_E3zO0aNh23i#7k7`0_45p2e z5%J(8dUas=B{G~`@=-ev9=;=R{8v;iPCytc36!TsATA z3mrcWKoCCfw<%|C{NAtB@a3Yrz_3OCY`w zYQIf0shrJYcwVVsi6>9s?viqFmA?ME>P**}rbcS+wE0&%pvY&y zJaDZJVEJqgKBah%609KBQ6{kMKx18uqDgSK-@gF)#{ra^dgq&tte^Vl?Rh>p>KV!S zXu%0T{^ZJuIY;i3ZnJp@o*AVt=GUTJ*)t+=S?^ORYmTo++_5;nP~r(9W@7tZsjIX^ z0rP!Ize#S1V!V72qRt%)rgqR;C3%`X2`r1H9vnk z*rRGML7}0Z`KIXbFv|(u3FkAjI5>5DPJ%v5D(p0#UF?#cX5C(%_{X-$>VLmKWu_SZ z9Dft{a>+`mnUkMkkX;z~M8Y4n?}X0g@fHEc1IP2gc{(Eqh486%b`MznBED8M5MYl+ z?g%V!Y%92q&(|bnwJh>EFckr3bL$oZ#Oa``rmUW6W_}o(7k#ZnRvT2j_{+n}rP4l? z1UQ9US*VTC=m z$)Q%lfI70cO@XJTl!TGu2}MoS@mGPmdKpvueKn~d6Zop2%*ObCC9d#-VK$_8O2f-FFQO=r6YIir^p0tv}e!( z+vd{Y!WEdQA1z4lPgeOXeAON`MZ6+qmR0-VY9A?n7@P$P?PnbscW15r?I+gEXgR>o z1o{XaSTW(@7%7;XcI-`#Liq)D462K2IWe*n?4Gz_RYl+9cQh6seLYF5jm1WZ6Js_U zt9=Q*`pt4!kG)G_s5rdN{r-x(Z2Fi`8UhF!w7PdZ;}lF^$0mo0n3iOK$SHb{;!Iav zlCWI)++~zOgk(!~b0&{4gx9-`b?=MJK(~SUb@a;BzCe2YE&&jb3{H`T=yWf@cR2*G zIlz-}<5PYQFBftsBMJ^3aRT?|{QqXh(LQ!iXZ*X;M-AfkBwst&-GE?@94}R zI*v3lX74@_jVY#f^ zMs`)P{=j$YomI~gO%@pQkN8KUO#LWC394_BTV*!CA?)V8#l2sUS2@y&m^7hcw4@E^ zL6-W=^RVg*8nuGm7N0aET>2{Y0<$J%rAkwkOd>|`)}LuOiKJ=oaI0zxs^W1Vw8Ja( z_Ze6GLw}!&m7d7-%4aiC@^AP~kj?$v3K>?qc6qA{JOZ$_^TP@9f?>-m4W=Fr4y*_W z2-eq_@Xc%je>xAGH0V}uiE89xmgPr?JGV5ac6C*L-I{+Ye$s_qz-O@JwP2X99^0Be zkW;sth?q2}uR0#722Dk%y6`l18)7N+8{zwS%s;<8Sga0dS|9EV zxYLZ;Sn5L1E4%tzOIc*U{>HPwc(+r~wZn%Q!IKM}=TQmB z(76%cfXg8Q1OOllKY>*N6?F%ZHU|VE_7Ku;$5AqFwwSd)t-c`Wync3XQ2%JNyxvyU zVM;iE^MYsOh0j94>?l;5{I1ry%Lz{K%(aqWE##lKRr(V{Jk4Lq{1Z(WDwAB^!zmIz z;WKsLUYp{dCnxU=D@z?>?5Ad{j-$tomv7zy>6DnOu4GaK7^{ANUk}flYjGw(q6yoI z0GX3D^yQ2m1b+X1huA*3TEr8`sfXvsXC=R7J7)!o;Fg8!-279a=HKZ`@BZW?d3l;W z^~N~%^>@SO&>y(A8q`$>XT1Jqa-Cc#L}rxa*_*G{C|M+aMyTnRA4aKJTp3)wUC-&*7dV5md(L+rD4N1Y{ zoAv0D_zV$qE>euMy9Gyue!b)S#jLAnx&8j8PIJpI8)ZWu+qcogH`4=IIPoB4=UHUh zU*0_62>t!$+OSsTQ4qKAILDE%#}#pVQAy>IX!^~*x8P*%X+hwK(<^-dFnXba-uEyb z1W^9jn6J{4;b>C;&}=puWA?2QVWUnKb(vSCuc2){cO~j2_?B=o-aiDE3J%>r^d)$| zXyfK$E=wAiO_}PqCIHd($E;Pn(6dw9AD8-!VhmvrOs1~sewa`Ffg3~;y$fbW0oexu-|#kZ>5Zcj$sp%*<2>#dVz`{VnXkH= z?A`W>B2l~T0g!gz8HwVDrnGNyt^1s4vQHqW^7d+&w>>4D_F7ASsdFQq@>@IhD9PGmkBv3adE*M+Y4FTMbzYy6GswBq^e1 zCn8fsO906Glsyg?iYd5ewJ(d-(6Vdpxr?m%bN4c6JEluZ>C@-+WS8>AOh=J9zqp{C zm>jQV2Q=HYgs^@R*;PKhT%gWAy`$myQ;4y|mv#To8)(Wr??aX`PU=n^U18-46Kp%4 zk^%qSxm+-A_6u;KcmE=oTkjkCJa3$GXc21;@qX}w>(T{p*PzjCi+rCB^mV+>dI5c0 zv63rnr^S6cI?G-6o9{|Ho8|WYt5bHkKmI~lV&w|3Z*c2QnAt;XX2S&aTFQkxB>b);q9eF2WLp_yJrGjW5QBf zynnUF(VUd=r_U5s(`rg>mbwQcRz#HM|B&M=f0Z9a>;49p`8M6}m(JJx)NBt84hB%0 z$gs(%8AaOP83(r%&@B!>s!)Dw_4%5sKrSP^cf#vk`$jaM+h&}Uc#cZfXhEOb=`-OW zM^ZbAmvs(r@k&dva#Gu*$WvUm9GBd}plleiJ5P955Sd_7E*k%ZkolzuvX=$I$G5%u z?OU#y&4De&e7{4d^QDwmWNB^{o$H$G(Bn>ydERn=I!3Tj;7v(F;EUIh zY5BA=kKGyssKBvhn;zd2&Kr--$av5H4K~0=Tx`G})%`24HA*7acH&V(kzC15*~KaS zF*2`&TCp5`!?TmcLBoa4+W8>yN4H}hs~?_8DRGS_!#HCSe}KvzwVJ!%P*OAL0LOz_ zrNIam1`cEsV-jT;;KC$hp6@vDU;?hmky`^M111f=-jv?NI^P^n>Tn0|=HMuL+HoaM z^Z+FhUS?Q$VwYC?f}{84#wO+D1jOxO_&x|A50hSh`clNF^zn0-NnIBP2t&t%`vQ7K zlKU#o7(T(4UktO2R$0y0_`y%5rs4viTJGifsm06MUzl{4vAMRaM4&3+-wPx4dNbD1!hCj$yKxdB6yMe;La9aX8-Tlth$n`9a{aH2F%1*X$1#;_5$>>t-iUlCEFG+ zTT>`Nsek&;8r1;=E7z@vEM!J#N)0Kc|cQqnbpJ+5fDdWDAi#180HA>9{tD!~6bg zFo<^HVTf}0hDvhL-YK*hp@1Ag02zG_&r3&-+*90B;^Pe_hyn3q+Yn&DYd|b(tIb=M zYPj|_885MRR6h?32e9e(782Dby*^Ewtsw3B1QJYu3)$RW*fMh9@FwVO`GU~8> zZ-SPnSl+r#XVmxZMAGA5UC|`{!%hBXo*rx;+70?aHB4eh!H+McV=gk0)W27OBYNSH<6R6=2xKO9BlDH zDnyzL7i$8+X1wB_EnsHcqCS#XV1Z)QS{>LVcXxN)_KIv%3hY9uIUJKbotBMs2 zICUm{bug52T+5njJo}cwA>bTal(9iZP(c6+m7c zu#KpIk=1Q)0Hf>}TMvg7h@tK5p4~W+r(pbT%(AG4w~E`Nf)HE`EnwSrr6%lfiXSCe z(RwaJ{V8<}tXLpgypfI@hY1EQouu;|bXx|oM51EOOofuIXYvfQkUAP!xD3Xmt#{uCLo&^E8i6eYlTu;e17E@t2{-_u#_~7*g9d z;+0`+vFuuuq-XI2+>gDB{wwiorfP2LC5obpr!rx>({UU zwhByg_NJgskIj8ulU0VRx1THJE^Wy^E8H$tm_i<$Wd^i18fHEyXZhL{01FuwG`{B) ziOZ7v-!}B8@&3VZK9Lw7Ojdw{LiGl!l|lsEGv7r?%KI^}*PzGnMAKo)n3ctLh5zun z;^s^8*YBH~o8VW+YF}5B_|mW9;|`k0j^brFJ?S`vk^UZsf4IpTp^I&qh5)CElf7#H z%e^mOj+;!!K`?Fd^Z0Z{8Gi>!I*#QE->*7zw# z*{kqR-{gLl34UOuQZWAdHSYU&yx0gO>G(0TpdX~@m+X;e=0G!85-ODUOg$Q1nW(77 z_eS4daW5>>Zr-#-RBe&}+?as%Jyy_E(KjF7QbM|=?`~C!J!DSB0((cS^YrIRWeQY5m_F;nTJlL`eIifLv%43YwnW%bOJ!x>|>gAsJ_-CQpGDoaN}{ ze}mQ(bgVrub^g?1rf$y~lC5pu@5MalxNl@`eiZ0E?Yhpsj`Y%ldIl=c3XFDg7X@Gr z#%9(2)y*j1PqMycYKNo@13&`jNTwU&>iM1&6f!=`1BvWOE;7%+hTrQR7gyi>7tZ`*ViO2SFRkn zMjEmwT!pG5j0`&#swick?xP-?ECvt^WW6xQVm-U=G(i}zK1hkyhSs|LfuAT4lKGu| z5nKN^Md39{(sv>Ns?n0Qw~c2ekl}xAZG3Cwg06CYb^`8|;|ayXijbik!C3l&w5mC` z-f(nqqnn#9ircO(u(LG}Q4hOMi-o#bvj(Qq<)+RpLrM~AEC)ZfiWvR%C~}nZiIZkQ z8CMP=(2|*`tV@PGZ+_ozfea`ayCwI++NoMZlq5)l44cQ~;{HA6l=4OieZJxyQ41F@ z03qlv7In{oRbA#~iSysu2t!1F7x+fH0P6?vl`du*tDJ;HKl9ZkhrD^ zG6joVBr8?B!-msEt_}M`V7)we$J{kT$8S2Y^762Sh$OSnvZ7ocQ(ef@#8Vo4xGb9@ z49JZ7aO#2?&)k9Sm)ju_n#&buzp-)Ie17AD(u&c;lzUP>8bG;TRGDM)qNY+BycVfL z9va<~ntJT0^XIcXI@>-XL>WNQ{y)G}FBors;so8z_N3Pc}ZT6z_g<$wY-g=OVDisyN(nlGP*6dzQd8w56d z+?5E7DJmtqC2gSoJ#!i{KoQGuAv8;+Sbj+VGUf%%HZ71E?d_z$yuE4+QH%RvrgeWB zpl4542I*ji0$-|Lxed%JWE0Eu!Br$aNBG*>8aOpc(P2U^4|bpw8m@(~m@@kSU`XY) zW`>nFUrD35(&}pvDGxm@4Pt2U^k6-|0%*|09yr>n-)C{XRmj~hl1VME0Kr-R#>J41 z4;-%~pXfA^wQf1WNBkzo%cch5*gmp8?J!c)vS0#f40}wJO{aaSC1lFYBmQAR=dyOqZo6J4B%FBbICca z7lJ0|uR18AB*i*>p8AImTXklD0^*s66hT(Eims4x=;?`iZoBiUI0IeC@aEv=jxa&+ z2m`1=u{(ZRA{cGZj@cWaD}8&f;ZnBc3=IfkW88X zXGm6hJb=NWCD9se8`fg!W3uj_CXbe3a7d3ctFqCUm^`obc?H zwfC7Iz6`3vsV%e1!bx^5{KO#b`*UJhKWUorg#B_&#&_+YkjYOx3BUtRPVB%xJOuup z0FJ@h5!~;{R6j$+UNz_L@^3#Mw*{Tnt6w=-AX9r>sz;jKFY%*M4v6it$9XI>{$u_% zVwrKe1eU+9>ernJHoDciZljU(H!tBVO);YaGWhTDB@AW>Zkm(iXQ98h(g^jW#)@3G z**3$J>>KaMx_$xvg6Xj#08NZ{q^E~Dw(TY=d~qqmpnZMAdWN2-++BD#{|^h$ik6Xi z_-5q~$fb6)=ow6=OA|kURbjNej8k_O1qY|;fS{cBXK_)ow7*~JMk!aG%lB?b|Mq&( z7CJXoQe6BLNFPYhmBM3UW2rp|2SRUElDcoQeyi!qZH-I^nV`^IEj-NMTG3b#>^`1HE-I&%BGcF&Ct@e^|sep{=!uFE$Sgj- zr``V+eu(&(^|*w-MXxCWFl;SZKua(dy2XicLi=p|$rsD$6Ek5lFXu?}Fh!fFf+02N zDG7b$9}>?+j5cK&ALL`on}=w zFNSxxMd+z7OYCpHc}+MFKTT4Pih@MMTwV9jFQh8pvq_l3oiB2uS7x^7HVdyr;s(>N zZZl3Hdm=HtJK}ly;PF7&l~eX(s3JMezZt(ZN0|W;Eef=TMBC@~=}~YMDj`S@S{=<2 z|IvjAo$j7-SH^q%ck{}!0ye6|rP5Vs^=2njJA+T*M)@Gk*$B}wwD&;4y`^E)>tt0!?HW~^*I|8d&~ zamBm)scbWpYI*io=|^!`I>9913bwY!jW_|)ZLYM)TG45LjY8TLA-8YPED!H_BQm74<0;hv}H^gs?(<%7?rozem~ z6QgV5ssM+9+yUAzROshzzVp9^vtChTA1dC^gHx`a$5)*~_RBCJR(^Nuir=)01_v8!cc#*SyU*-T@4j& z$bR70M3VmPXKOnU6F$->Nn|meKkxum1ln({4$|j=9pgU?FXd_8HY>nYC%mzVMd}MC z$PU{$J_~0jP?1NndPJS%6_*6C2}>veVP@?UjrBQWtTEX6Z)uh`lYd4p(03;3jbcjpc_U@So_4GZ)Y!#Jd%~`Ibt^+9z z2+eOdPP!TeQMGXJViXXGw@)V-lih?oqdNgAm1;)RTbT67f-ep<=!RMWs;oPov1B zpvCpws;}4{nSv_vI!or!`^cGP$Ut&T3`(O*qa|3oBKij6#6IG+^C4y$W@cv6;%F=h zGBUE9qiFCPI6%7iMr=<95#F__Oq0aU8xKm8#AJ0*C2M9N!r>0=S0xvyEzpk7Z|BpK zQR?3NzVAab`X+a2(b$kPq=sYlCpA{?qlsK;a}Ce*mYs^HTh^)|#Q)|QY21G=Shnb6 zMXDvkbI|GqPiq<0Uza8NudlB^0{#hdsEc6&7O@%MU=4+IyG{q$p03EHln<+b>Q(FG z5x@a%U$UgXNh!UU5h58aO9B8yy`zzcJ}hmR0^h$#?sC1l{Dh8e)U<_~7FRm%$%8TY zk^oV-XZ)#r1no;@g)k0YC==rh#cW)k*L)8|k&w%T7o>unR$ejsCpU)7md_eQavlo0 zxDcRK36$ZIwdl2h4xoTc-vJq8J6#{Lq7J4#v0_tyIhu(-8NmKL=U1K(z+;B)TpHrz zS{uOK;Aalsk9mhNVGB>D*#oq-nUhScK!nixl><#cWhL4^R*u?-1R=*7AF9%5En^Cc z4Ui4*H3eOh9J3-nbSC~li4XcYQLB)im*e@hn!}7TH!x9+SsLjJpdW*EgrI)`%k*v3l1S=X^ry~NZr7JYyNPeb47!> z#J9S!vyhts0fKF9BGX9ut7YP}nSE}@N*@T7hqH6*+~e4x;^OimJY5Ye?n7eUs>zwQ zR@eG)@n3pfQz7>1syQPowOC63(E`wm$8JB+9aP4q!4U}|owYqk6R~u8_obKb={^|T{BRhfHjjgkL zBeh93OX&TubSX<{UQFbL8Q{ah=UCJ+1P+zd>ETGkV_K1YtK+KkXzn;psdZVG>l<@t zl9c=giAuxp1v87*QygwD8ojQ=Iy$%fRke4`A&%w3i}ffli1N)Re;p{~tP~jT8F2js zq$QAHqbV`@uJIwRpZ3T~Zfwry-a<#Yj%w{UW|e&wUgIR=f_0;E{ccH^K;WITNE9Pk&>-~YDZRHSjiES?p*qh zs4|YVI(M9P#38(`b0wR&LFsx|?{!&Uv#o1T`mldQ4>Z6$c~|W2J*lw-bhs6W^^x*} zD#uf!7oNLo=1h5@ZU7uo)CjRlPrLUSEtA@!pW9GUPx7g2Sx?4#qSA)KaLirL5!>ey z+1s}<_Rw<$T;VqBzRetO>_F^8(o1;k%8bV)6jq-*#eimW?QNSIA=jzpGR4r<&WyO? z87?yFCDPsdnyP%oC*1f)n|DwEn*q@40>gutzi7?U;m`=`WT{myxUQ<%f3|IosRcADh*M%t403qNEzTepbBa%xEGz}kX*f5 zVOn~xWHgMJ#b4wirRMngj98}N524&Eri5HiqvjiY#dVVxBTasAJUf5(h#a$H04*8M z@}$_|i7-kcd!GMY@6UQrL?sOOBCLQg4Cv{(e#Pe%LKqFTtc4r8Ss@Df?tt;ccydq?rD?-FAKbs(uL7uSnPWDMJkzBgM9!<^;M-^6`w;R+)#u*0# zx7Vz`6aQCur>d2f$Xw?*AR)T{Xr`$f00q6hU z?>{Hs`biEd%B)M73cs99=Ig(FpqpJ1J8V){tUi*W)D+Wri39kH-IPjpY7hT5Z}mL` zy`+xZ>ywQ<#s26nSFzR(7G#PVJ2Wu+L^g3GgK} zUNFqW$N`PEc+p`wYuZRCG~DIeDhX!Kc;%5&8q-G>knv{2>p@>pT=2dVxwtU>7{4}# z6izl4y(-^XvhF@ssGd{J+QVb733@`l-uDK5M?mu|O=0BL@$qq>PsqF~<7}zGzQXK$GuE z94{FivUBOQZtk6*q(0%c`2P`hfQca?1|$Mx0zjFpk89BGB- zKK1|JF#WfzEll=lV4?Prdc(G*)ofA@N)$DaUiqG2==5>Nw%zj`5{$O=$j{1sC zgX*rs6+gN*Au{hb6bHz2^s@60J8u@R=3W1P0lnL(Rs?pBs7PssNz^2)h@P-sDGxY= z*spt$QS9Z##RfBFnp3l$g@1XM&Dilr1a5RdQ0he2tuFvUE6E0aE^yyUITZ@4{Q;(k zo~4J$zLDtR+Nj}gR!Ca|qL>b%{f;~}0H}BKFpCzNlCN&`VBGlHj@{N>!Uf%l)8)X? zOsV676eq?cK2AWVcTfH1Ld2Eybvm0p z`^FWP$feu%&$-X2m{8w8QSjj`hvenF13y@8ROxcmLnsGoL|TKpeIR*TD2e>oj(&cDSL-U4SxHCp<99x(`Rt_yJWd&Jy zUFzu6O}7b7x-tV$f5DP(f}fxFT{u(wYu{{14*2n0(u_ zz(K|X6cNxCUa@$oj=M3>3!jy*)1F+_&(CMTnoAPkf8HF62*dOKc@W>&5MO>_vj$2a zem%%WU3^?7&1xebNOi7B2~6(kGgVc^E}U#Eo3hkNDK{;lR8{35WmV}{^v^a$2MiSO zm@MyD=o=o#OkN6h6B0{LTpEu+y8qhu_EVnCcd>szxiEQ*32Owu{bOFu*|n|sIGx}) zSAxoSa4Z;udC+b@u$r#;_k7mG04V82HjLy8=Nu}^Rl%T+K6=sy6o}Kl7RWQ2Vu8Z7 z`GrbN*HKtK4x^i5)+8t(zz&In3LL1jBgvWavDD+7!_hnMj1ZENZsESRg)ZkmUJO5@czKSd%WdX?#Xyu(nqbDLgyJLF3Hd`<+?At1O_tamIgU2viJ z%!8nN}&zMPI5*X<2u*;WeXjbCoF>QKvUP<$U}yF)*h}gIc-} zH9VUU3R*f?alR3+`u>R;t2YKI17uJsZdWsQGL>F9I)wU4KRa{JgNo17kpD2DacMxe zbRS)L81^&@?$oFyR-NSxDLQ$GKE%N2u$B*?w{ACJ2=*1D5-eslA=qJOvPe=XwO>ia ze3(qON*5SEVFQ;fS^au3H2#0lGvFjS-*4HA(|^b)EFHHLHeWYXFjZExH=OWjZ}}=M zT3FN{zavq?*o0W;AM|&RpQ0_MPU}5LXhL4BO2S`sU|egG3J-*3XFxG&)RG~TLFCRX>&uL@4Fda&ayI~uHC)U%r2O02lgo!sWj_~@r*@y@3N5yj=r zv|5rv+tOSfo<#hIgUvrsENdzmW?#d;v_K4M+X@PJPvs_M_yE|G5| z4{h(Izf8#l@nTz)$yv!wdrRZzCsp%h(7tu@PF|QCO{7y!<#(@%&3g?Y{Q*uvB^4oO z9lM|JG0KZZ^YhxUvY4L=+BxlEuBG_zKdWU}iEGC0J(Zlne-7NNW-40y`Ja2|O3F|f zQBw>wD&r8QWzDeh%5yX0;8u!?Xu~5Spx?!jZx&^jV!-M7uVKYe$^WdIktuC*Lcs^G zUS+0|UJ@h;87L`|{5XrojfoER`IQ%GkEN9Vu4!#=XwQ9?!t;qjHfR_7GJNRaV;t3G zc#vA09rstty)CVAns^lQ zydhdqnGb<{9NA@Rd7n6u%*p^?^a+D;O;r4AM1LYmjWI%QA{x44NR4p}xo7`MOE@wf zx#4%BF;*?LG|s>Lt0d+VA=K6r1)nEYIu7mytsETAse>=lcidCpkCj;1wa8{aFQnsb z&|=dto+?$aBl5slpDj91rk4r$(U<$N*TX5r!3TWdWetb;B_mGF^_8trkQoC5n+V?> zk<^UH@JU-L)>)|-f>bjnZ6JQKkpdf936sJL*6iCNfyg@-Pma*~V!E;LGf+_p}7SIaNJ_p-rQW_Rga)_a*9x9W8ah2)4Zuwta6flS?Uhq() zv+&J$`C7R2&03Hj70r+{ikimZp_Ha0^TZ`!Q_M(f;w2f@nne97X-NOpC`QUqkV+H> zi_-^4F)>Uta6lQ$97sO$(u<(Srl&`SqLD1L&^1jsJjs%ihsKtTbuOwlo)u+6y#7zA zIE*dk!I8<&@qNih`UqiCFe-?_-E{wobeI|!Z8pWl6X|r{lj^?jRHbKK_{I0RJUBur zXu>9*kV}3NO(9;JMR*_%A;1r`$0Fn40QtRvanLdSXc>mSTVsu?CLUt>Qt7)AO!1ss z9&VEsl6o+F#qrzJIb{^Tio#3jX1Ega*hDBl(9*NosPFV}p3ResN*(8 z2mNezWaya*vqia|DI&_!&p#}j4CN60iTL>@kx(u!0i>kV)NqbFBlWJz4p*zbFJR=w zHzC-yAO~bn24yJGBUop1XAy-Rn&F6*FqVEA)|Q?XSk}4ysGxfb9SuP$AV7rlj?u5Q zw6rZPJ`Z4hV&&U#*D4>tvkyW?WiMVW_DI0sg!Nb1#F0Xn)4yN(={JkSm$jyJ!QdxN+KB96Nz`4+#)lg1bX-5AN>nZri;3dw2i)arf-*96oSnZuhNQUHw#b-KVRyywray@QO*$ zD~dSQa=A+fbC2ujE-nA zHAQ>2nN*ugD}Ja!M&oAqKU`#br7>{g*&(;zNl{TT6u?6q z-9LE~zxD5s;l)wydXJ;>*v(3hO*RK7PYdHa<8wxG0-W^>+oO zQB*7da+dj>Np)wfjoUJe*~8sE-LyZ!-p#GXP1)w}+#IoKU)=K$+C+xZ<04dp4Tq(_ z)|fT1v9Z7O#?s9-5OOP5n8~^(D8=AhpJ3h-U(^BR#j#dHGb58?Z?CTnKpja*8JSAZ zTZZq$?eY0}rQ6#)WXu;YYjF=3NW*x)C4Os~QhzW8qZkt_75Og8&zyrjd$nk^JCL*- zOn{-4rDUIJO{q>YCnc(_mFR_ofP-LrWi*CKlCs=5=ue5>f5DBB3T0iZELa=(Jj~gv zkyqbEhdL=y8zFEExcV|up*?ud?SZ2J7w}8x^BT)Zt$UL=h6exPx?fQ8d74U z&mwKULeOh!;Kzf3d@o7c%F}c7eEpF)-DEKM4eo5TNE{Re0eGm_x34A7`#sq~;d~Kk z@Q~neJQ)8@JqGNgSZrKUIo0LtzfPdSW2ME&-^D=y7miW@9YsZNaLC9akB{xWcmkn4 z^t}7D)+R!r=CCj*pLwu=24Fpx%e3zm?Q=zUk5}MQP>^-HflyRbE)YSlPn=vbISEvi z)jaA5#o={Wg~tKKV>8|x$w`Ttkm9-bLIvp&T01~R;5QiMUGlU!pp(g7BL3u9Y%I-S z1U-&ZBZ#Syl&Gjl-5=6y$a7;VYj(QVROMu-jV8xNN+@jXuc@fYgY#>9pL^-y$T6Y? z3dxuUxP`b0OZT7K%_r}^K}hD5rH=#T`dMW& zGinlm4*?Z|2iBjZLK^rK3O$Xa2FEfSW`-rjfk#O;w>>W=lQ;qYZiN4*I|q2yjalV_ z?iL0he$ED`r+QQ-IQ|4G0|r?%PKcyU8ZU~$NC`YfoUU!A+Xb*UAS&l(7Eo?GR%7SS zY=Lag`#n$^-(=wXNP{OyCpc}$6hpn;%$E{YSsaOr%YP-HMIiHeF!Qur|mg zj`55qFXMJkO3<*dVj{Wz+q+~4kSyl2qBTG=X1=U z-+JzMYPYX|5TI~SK7|K{4{JL9#!X+XvS1U87kTVPK0?G$d`Y>PRFeA9Ja(tV#4Qs z6P1eacg}`n2jjat+Ugb8!<>^9lSs;o21$@hxI7JNNDYO70`vPpd;mF%vDtNKo|^`Z zy>cfM&Cfe^#C@BGt@t~du}X3SnSkL70XX=8jg2ll!s?c@(evtC;qcVm2AD+85&c!DcST5Z2@JTItRgpgwdtSC=A#oX%^b zxN`%N*Hm*WdAba?z_R%0LLSxov*Ny`sAMM)BV)!4vPCL6WR6?JFl{Y)P?i5OsSEDZ zWl&R!*UF4=5AEUOwFyZ4M31Y^bc!9r{Dun_z?ifYIzZ7GME>x++SO4#(V=cb)J7R8 z0I4PP09@?`z<(sKGw#1L)@`!j+U+k>II_IdT;9k-rJ~S%Mv2p4pb_wQ8Uh!YSF3CV zfO%bpq`0*|n-J`fMwGYCOx12okJ29<{i>=xIxSv|ip%04+NA+-bdZ zyKne=;+B>Iq~P4A^O!|N6PCH=lbVe3uRk^FE>0<0_9L$ExYTQA4_Jqi{Q8LNekN7? zkBVIt0TS(CRq)Aokp%)Fx0`$FyLs3eno~uSSN|07+`Fd->B;b(`ytBhyP+kKLzGj! zVCWm6brfbLRjn*~uEyY$-%**uZMxl^@~f1-WEdkA@X)mKTI+AsGLWyMxG<32XmM9t zbgY^0z%oOUYiQN2inRC@($LtUTHV8Qon|6jhoN7VTH^maJT~cc9^0EnfdP}t$BJeG z$Gcg6mx++zP=%?=vnV$u*=Ukk)Hvg}kB2Ib27<9-ie;M8s=`2A%e!dQ?L5A4EZgbP zvpnaM%ov{A_wEd!&HCLbrsrCW%_o@4e9RHODeoJrT=yYtI~RqxI?ugq^UfT=43nvB zKw?nmeVJHH;uXh@@y?My7}xS(E(E=ouD#^i9ZDy$?vh5sP93&_8cN%y0f{{)d>aca zI2rd|G|H5pa)!pNHRlIgBsg$aQ;+#+sP002`?5HNPQ6ICnq(d~ww^WC^4-L$Ys9;} z)8%L9!hNi0+LmXhsF=Spo_jLpOc!-BgFG>0dkQ)0K8C(ZMkcg}#r>cx;$ElQgVmVW zqbZS+st?0;bVkm%cx`msiRo|_N|Vs9jmgKK%XrK|U&U-4R<)mB%?w+iO;6Ua7sM|ChEsUM$6v zHf!HD0|`>lm{&C*u_+rNkca-Vdt^iubU(Vsdt{1&cw`FAV>?b=V7HI6Y+D({7ai&c zH*WN(VSRRfLlILaUyOH{@E7}Ek>Ke;p+6J3PGZD*zEFE67yyV@B;x}e|JNZvp&b9u zJjwsxuHpZ5R?nVD3Y(0K%=SopJD&?yg{W=+ER>^aaOtgV!k0|1>eY?_QU!)lMfN-t zQL30;V}uAqx1)LC{Bl9FVYeq+G^sd-Uu#}kj^XcBgHdy%*7OEGzpuqdil-ud-5>XM zuKND;!nOAzY0D8kj_=<050=XY)upX&U=s0^Le5X}M(G4t2u+9BS_3;ACV>*E;Kl>9 zpuW|qlg|Aa4FjTDu%ykr-EW!1wfm#)Wo_o)y&$nhvAT-mQ&%jvi&@<;@k#M*%J(8bL{a_0aEp)dyV@{k0wkb{*Z~ zAG0Mxa@Ai*>N-Zbd8}+O1ms#DmbIx}aqum){L#IpmLr`H?J_*la2eWDXAxnM=JrT) zqT6L4qZ1I;Ogn}6SI}Vq4KahB)(k(8D8S1*%Q=!6Z8w`2|o=gSQ zf5@gn)O|qM>lMOFOjNC+Jl55+ZWp3cE5jfvtyjqtF*kb@%fTKdm{^DZO#5- zm2f_+E~4Y^Bw8}7u50nWl!guga#G6?Lm7Fnbec(MjMk=Ob5vHgv-k2eieKquVb6O+f2~ktuZ@p} z@?(6rfBM9crSw9|y>{#Q$noPhFN*vus=Jx%x+$sd7$%>j;&&XwcP}W-!pXzOx?dW2 zKMf_F9Y`V*jfvcypTJ#P3FB7`!?o_)1yi3q3RDcE5}`K}{X?*d=jg6pXdw6f#W)IzMRNTiuoU6PYP@kz z$k)8p`?10PxaOj+`e3TFg1y<#JbZ5ot7)z|(rquHdnfbG(dSoY|H0!%8miy2%{o@$ zaD(Tg+qC(n*8)))!Q2!R`~CZ3eTip@R;WE!`t`^&^&2(`rpsO*is}Nh1`>irO`Q?9Zu3tR_Mze7dJGU{cnt~7W2`5Y zJmdLI#l7V___ql5<`rw_OBP(NFV#Rzkz?qB1&257V{>Nz#@H(s)S2+?~tq1QgW{- zGFA&U@ZF#4+z}PkS&S}Hg~}YSEC=+e;9?{$SAosPc1zr!Qm-C$Enks*Z%Fu^dg6{BQt^Ty&We84|i`4=CO)MSM;l=}Y5064eRpQeX9?Xz*y-YJ+=M9)etGRsPxqnVQy!$Ob{bPAz%dSdz8_yK+Wr$`;QTW1Y zK(z8bx384Th`sqI2a@$=TKguB`pLe-_nFnzI%Nra{#^ztyhj6WzVN89E<5W)svVKW z-mNMvX0^eT051nOCM)BbfK2;SGY)RDy6~%6m(H&uy1&&JnC|>vGSu899wmP6+!gS{ zbLTfJnGJYI#Ie7>eQP6l!aD!Hkcr@RgI6&)ggLR+a`JGem~i@?St>n8`gT^r9|krR3qJG)k5bnntmjB7zh!k zbI8}%!%KQi_>_KTnZoQD!fgoLw#I06jG`r{2qYLxDU-ZEiTFUTX~RrD{X6kWP5Z zos^sBsPV!S1t~Cnf+M=KD`F%)9yyn`Xs8%iRIFaZ!ZNl`=+dT!GnNVyz}n|0mAL|h>uN_#uxJ_oS4)ED+hk*_zRKT z)IbPisst|rwnV8&yb2-&6{B*muQxd3fnzf_*c2E5TOA$VKK>rDE?)BTH@)-io+@-T zvR)U<>cOqGr+!ISvi`LAcWe+h?dya@`;T~s?fWY4v?%DeL@=;LPq!(@=ciraCxPkF z{V))v{2dp+N#^8;V7p4;Yk9_J*UXZ}7k`j+M<_PeH}Rp&@TSpbIE~L2Qj=Mn;moIB zm`EWB4eRO@lK(A!p#I167qw@}Las)Iea;I1E9tD@tvE+MGr4`LXBX!w zYen3>aH=6~D+tW}JW8us56g~$jkT&jM7u?uv6=~>HoMq1sm=!V@pg?D(HXh@L1S!8 zphUm!>zUtA;j1=d#iB_@i}^;}N!DMn2en@t*pCY_{oB@ZjrMlM#6$BbkTtmc5bJ#f z(IQe>g1pQY_r#Xu3BIw`H;2hD6@*0Zvi>#}FUfyveWfYX_+V%rzCALzE#4JJ-b@x; zm^fs-7So?p<%$1ytEeqf$Zl<>BDa?KvnM^e^!?Qp-~MbVEC2?BmGDr4ktXzAPK$5p zRx|i<;N-Y7arp;A7Td$^2NC0nOh*`lg;mYGv2=Q?mPzC^bdo^7$9K;t(C4aiKKS&{ z80`O>^kLRE-uC=6t}l61S6Tu^%&RK+Eh>NQxyq@z129^WMAa$3;3xyVg-n~s=@B{U^J| zyK??Vy>~A@6ODHjzg0h-55CUI)o@sLV_JGb$|+t7Nig>NuzIFaX?yONI{ZnrQ@M=g z7*=>8*v?>MXbA?A<~KiJQeJ1fLo)80O{r%&PXhV(dD-G;#Cx-80vT3H7NR!Q!c2$h zhCIavGVbCeJNLu+eCpbhqowW9D^z%u;xF!~VX=o~WO)P{gj`lr?a#Af=O?{fOJ37> zCvL16Yj-RAiUx-@hi~#pRqt?Cbiadrx41XIIp2L?Apx1&*>UTW@6eJ~$CKV8s#i^M zXPM7jG^QKU$L1fw9xXrgbH9PaCF9ms)a$v(%#)NO1DmCn(^ZCI)y5zU>fFIW8u8lZ zg$Ng^6Q#=q=4jNIB#bzG5f22QevE-8l4gkW>^wi3=Qn3o_HXH@UoxV@Kwh{x$;-UP zrk6^*5SvKc1y77}WpwW*%$f_UpyY`s1soApLtum81 zukYx>22f6{s8d3`cuLB8PX!}eWf|JaSsRQUhpUrF%82z#W?vK~>v%1UR1iG%dRdH> zR*$SHdsc}Z^x>Y(2Y>AzF1zgEvw3%e&$0;pxHqNuH@kDGwFS+*9W%xytBzdh{02$A zjWs38K<7!0VBwD1M&N?HQ3@=)=9_#-=`eO(?|h$d*HEv5H_lE5e-4OV>g|)l!>b2M z(tKEoiDd&Zt@o@QZ{Mw2YoEv1NZqhP+}W9ga5g5%e9oMR*5y56u$!H+^WrBVv&cNc zn*%2Ignzp8QHL|%o+55ZT=cBN@s4V09`F^x%NneD^m}ZWHr3;U($qJ5+GJ_x#J0{d z<3jZp!F!F9G;@2>3G^X-Hw2&O@~1G*Nx6F0of044hu$oi(&k}@!uPw{e+D^Sr+I;< z;{*A^W#!*gW`6uAs#cBRiF-WU?^|s1ZL7O@^qqO$CnUr)b}fA47)W*hS@vH3#$PJ$ z!s_FAo-Rg;I#cWNi!q0ls5NN4Bi7PGISjy~XLP)-QiZ2K@(X0CiW zqMvYk>L>JDq(Y50ZLap-oTN_)lBfOT((ZZpoR1GQD=KxKb3ogk#*URZ@4)D{cUyJq*PvL7}eKpOgZO=+tT{UnoIgGjoV7wQxv@$WP+65@~EzeNe+uwZYl zh3#v85a`d$Km9)H`IJZS3KjxwaDtQ7u118kzgd)J#LrMXQ+pnTG5nox^|%bG;`a*# zd8$(Go?gc5H9|zc8a`wjyk}$A-7Kp9RchNya{n-8nag^-8&qW*@=^fFXzO-k2BonX z7Zw7sitj7k5mp;!O8Sw44`vzn?JuZk2CN4;2)#v_@Qi!Z#~cJt-@H7UHFtZ;FN^yC@da9BPvC;>J z@hWXdZTZ69CBTlJwFj3Ww!|m(haSb}(@XpAD7f%7&r*HO`Npd};l(KEPwg`1;{sKw zfs}V-N51nD(O(mQqU2dYYGfBL#gD)b{~6nPDiVHM7G6c@J$X(y@Rk}2aF^&oW0rPv zlEjz3edtbwgW3M}#V1(4<7+_L@<3?d_CF^I@4h{8h zbRzHEjQ^Eao#DAX310C>9?k^EtW;3q?WrN>D?PjhAvA22lA*IrXFBG!DspQARBW5; zoH29jITHh7Yt*a=kcXLm-=D&+6f#`}o`)AKaH;Nd2R1xlzZ9<6 zh9CM^O?phMaxaCAe0J2zP}ynpvEZ**e|%ktD64I7xvsc9Z;mW=cTM=q{H!1dVN&jC znE_BnKu**3dC7hfwPe)6Bb2V9Ne&UVk0*logK%L*nG!;Wu`ddy!+>4|>LX!F=8kI7 z4M%-tUW9UpYi8%&K}?c821J1}^H{5^=KXUZhiWd6)KkG=v_P>EsFD&Ug$qoy(hJwLiX;Cp{UzjV8Z2~4sou^dw$-da_Qi$$w;K3zMvwKoq95*>fZM=CqLqU-BL^4VBU<&sPHOeGC9%x6SI2wS zFKf)U(yaGETCHf-wPU7-dV5I_%|}A^gZ+d8jX-#JiS*{O!)m)tw-8v6`mv_Iz2DzF z{AN|9SGHsRCj5$m$`SFkRD9P#USZPvMtpJER0KD9R;HiqFd#(D(<>x{7ipdCtT@qY zD4zM&08s#|SaRpXP#5VEpAnZ$GsdPS-jLGN_}9oz=_{kzHY+wN-S7 zM(6fDrx~jw@_UCJHtxxzvuU2QynlXNk$%`~4o>8lY?c~c>HipsZbDS@vjmg;n5WvZZ@QCVT^IjrCzA(l8dEFUcZuT z!g}KEyIfJ4-nnvAzD9`(x7i==Fj$Hsm$rBECu}QZ&r9aFbi3U9x9Lvu&bwU5kW%9X zl;?GWo6x?^>4j0~tA&v8&Vd&nN344^I=UA`vMJUghiO3&3DmGXIqdH9?y0p8+TcUo zeHq$m^lYx(S3B=C8P!oc@o~%v%;tCte;(>Xgqcjbc*{h;y?E=_c3CCuV9-sgb9)BA zJ8t*IO>iGJKMa-Id2*O z`^>zVqTkbQD;3F?{LaUl^FdthIk^oq&f()2gM8v<%=Rb6mb*|t2+}z9oA@^Jde$a# zpUDd8&Umi|j{SN|zs!A@9wNM!_~H2M-+(C_jZqn;&;|2G#KTxi>-|!s@$Kyc)aRO@ zWUAnczPNMut?`xuRp!5fr$c^OO`~()J=$}zCI%Ss?sZUF`sd*e zg=`jSPrOzZLzP)t%Vq01k>ky4uSYaDV_WJsSz6VDpT31@;g**;+-20#q4YRjKWKc! zgcsF+!e^n@pCSD(7r@pQ8M)Xhi6KrEyvVT8y+d8}zWexfoO7}0+;^YU+2;=T&-8;= zH5chGMM**)_gd4Zmyp(eRWuCsv0*GM)9i6%$pq^GLA(AvcQvbQ*qV&TAo&}V;0@?E zVj(;HYxk{1JKKIQLg_<&Ir6zEiL0ZPO!+&ru&Sjf@^BSRjvN*p^9t^n*gB0}2CU31 zB=Fu^lUB1hNXht}S|6^KsE}%sH%kjy*OFM@SvPVq5eh!f)x5p}yyo{I8t!z?>OnmbV-XPMe1Zr0)UJjf@4-Z>GsX~kh9%%jN%FLE7`!=PrH?`6f zHr=yI_rGrdR$G)Hi`i0Z*`5~%3*km$G}ynTw)^?W)>O-1VXqpot;f1yKAsVcdla$L zDB!D-Vx5s~N>0yD)4m@*Qlf1h6P48cnv!S!=3lIGc}$AWjXVYp45nHjW`!x}=@CxgQ0kcSETPD)UKdUNQ9s+w;os0meO`D!HXS7d( zMXA~kCIk=?{1sw@w`4CNZpZ%qtdbh1=gQrc}9#3P_}mdiUy5aw0S_ z0PdmzQv%RwD&IFrCoM^a0MVIN)$8H!r4ma4lMNdFs_2mR|Dd*H7=1-pq@ra}jc_TW zzqCaJ)Eeb)(yG)PijlP{T&qeGm$om#cVoRo(!GN4BB-)`I2NSlVscar2296z>@lAl zMRN<-4`!7SRrr_P_6Roii{22fyye+H1KsMQV8RC zA^cr)_p12{MR_r`%nW-Ts$la2MI==T^oF&$7YUtyB~zfcYuZ@mFhiZ#?u1qRTmZG6S zKn?Ogx(}LuwugHGwc`V+W0;ZtW%LhbuD2$y*qR+~m}6k#j_<+`;5|NJk)~+AkCIg;6!+Z0QSn(3=#Z|V1@tez zR5pFL_NsRs(pJ#4k$n2#?(-8BKOefee|}e*s9S?R?=jc&!}4#rJT8`m!eB*3Bk88c zI!i}t1nDN+l6KAC?K+L7f>D=JpMN0!5yI{-ZtSi@q|zN(M3wRPFTW`r{{GIFPgNYK zHMg*K&ILpQfzkxuDI-10n!tVGz3M5bG6^6i;pcq>Mb?s%bwH$;c2D-m-B`s|_o7xX z>!!X|l>u8nM|15Z=}a+=t*TY2`h8qXZYT?6kr3LJTzWy2{D!A5;ylH-JG;Oo<;&Yh zJY`D;95Ti-_$jpF3Jhm<#fD|C&b0dAvcTul%$T}LJ)C;72jaBeO^*dguMxQz~ zPxypk7p(oY&tWR+=LNH)^enT zoeLo4M|O&ud`Upe1(cI_wpx=BQZNMKxcn5ml)1h8{>KY>I}8?} zVmqx(!YdC=0*VZ{b*!MhP4>fqp)7hthVuwXJ$g|P=aL$JZ=O@F58&8KHUTU|CN)JPex4Kd=x+Cm8j7|(P zQ1r=ntLU<{58IYLiU3}LS_IHij<;hKeH;XGN5-CIEIB z`OE$gw?LBjh+*oN>sYuuOZ5y&bHip7vnFX$7K@E~;%4tT8zg(`Y14fxxE_LG0Z4u{ zvU&knJSs`X5@9oN2sI%p7FQQC!SKqwr->0!1`Y3w2 zxwiqEtYmm(#*5}dg5LM$a5s_5v&@)f4{({a>HkCg6{u4R8`t?hM3YPuWioHyPV#{G z`gTOJ?)3*ij#ESU--%<5zemQFI!qQ6gwvvT$&nmcN~a18nk^s@y<#JG`z#2SrtmX$ zZED>etz?aFOn)j19{J=^q5Er#UfoPCO6b*j5&|x|7PZHCEE7E#lF^*YR z3r8jF;EDvFO+uxPA=#()HA~6|n!RZ*X8nQa1LnS3CwpKpSc<|99ZfDA)LzC+QSCk4XRz zYS}n$biLW54hsRMN_1Qyyu-{l%+LWh*G-lA*Hl}}ac&)Gd7B6xpGHXPqUI4!^>j>( ze&_;E`6o=4ovGizQqlhtJ=P3d$=W6%w^{{T9dOA$Z8u1CnY*%s9@Tb(<{^Bi9K!SpMZ!M)&RJ6W6UE4ylTJ2j0(Q>T9Uap zd|TPu_vs6Uxs&P7cdjmSr0a$ zfq0I}TBBSvkgQ|{{QC11Bx__#V||cXTH?^r06S*QX%@DGEU>`EM!+g=sm0~gk-8o5 znl<}o0absmw#p_sdiTW5fYkcP^qIwMt)aX;vF0-W4uLB+CP=T}vL9X;XC8C590rec z7?u@ucMDsxQuR!>te?uEFOs$mdnd2{qEW{=G=dpt^w0IfmW;kjwB7S}Ek3^Wi8)~ZcE{cC(}PwC#%Q&-SElEC z6X^IuX}x1!@CmFz^!ic#;GU;MfnaYGY=7e4D`mPZ_xT$Ah^|r*A7%!4c_JQ^k4WbM z;+bugy*MMA2W&C~Jbf9L8aQqsBs}rqH^slWp!VrQg$B%;_dT5fgSqCx9LsDvTDVq^mLsMYEF=RA6L)d zj~h}~+odl*epf*tz8j1Xh*Sc@FKgoY`T6AoU0BHS%8H@o7?4&ZJ156z>yr8UYFEfC z%Hx;z3RC>ejG9FDohqUKnMJq(Y7jA4Is$ zQ4p^wipv2oLP?gI1sd%4w<#-z3st1{=bI7KS@hp@s3n_y9w--AP%wS|iZOwu@1(Y} zQpNkNA=A*W2b?xa)jvbAXDRA9Sh#q2?qPX3gG(J3+wl8>&VM(zSTyU=1doJ}r!pVg z$#dm?2po!!`|&yKKbmj~>tU$^(4+6_8&nkpdYS7FvD%0~G$oB=$kl zD&n3Z$eOS!dgrDFcqFT9FMl~)4p{0u+?DqlV;?SbRu2;YtNo-r?#1j>Z{zXwZ6-3l zR|(VS4sYD5YGiG5)pl|i(~oCpw+cZO1p!F7>8whXVMKAV0yiI3<83CZqw~j5^P@65 zqtMnQPFN=ahfK|R`)0(RELrA+RoU{^48-4d!HbBDC!Wr_N~@!@^GA3{-NpomCW=w! zkx5t^AE?s~jcA68|6q7x4TU=Xg|&l`@;t|6RejO=m7RKk>}{T%e51;13hjW?Rekdl zhm;z{vOjb`8r6v7@1<6YHY4mW&yU2rN%?9qC?8E7&i}!9Pi6~P1*?co2n)^EyWl+c z^yeto)N5F`<+E6o=<-v0`<%?l?9(27)vnYEfCd>Y&+xVdTP&W6(j z+xkqy6sxpcIPFG8kFWoxnNxSP{4tDJIzQKn=_;lX`t2|<&FVsj^gpp5wq`jOU zuODo-TCS%wD}CtDU^W{6rQ6CrBaa#Sp3zn0=b6Y~7Z04^VIU0gC}&%bP(&B2Z7WGvv*NoTkP>{Gfr^L4<3#Nn4JRCFu-WR+ZwtzJipFFx*91_=TXVmczLlm zE({)T938!FR<*wZ_w`)!6t+OUYe(v@ImM#ytVH8Qp(pw;-x3Fs(Fy*Ier!_YYVEc^ z#4BkDj+gc;+pRz+c>fAgIzMne3RMM6r@-c(#wgyqts~?U)-dJnC_=InAnDiZM6pbD zy$=_{LIY`cnJagm?1#GojcNYP0zzUF5b{Q)LIx|9ONoVlm;<-ITGjAe8#ix3qfRh+ zsBbTsoe!=l8E)M(4Vx6k3$$Luwvd@upc?@csXo8SsH>XBb>t%&&kq=`ujKL6$yYJj zt-(>`;e@X$K*+GQK>l#wzdHBAI65`K%y}PR>RL7fOzq_`E~iHo-tJU!a1Q$3>HztS zal+%g&`{F1!~%GNAzjcMr#rIJDzN6IZvEf626$i1i2v=M$V!qj_2Np;vYsXf3daCHs?f6&DFBN%IXybVLp{xDAds2_q#8zaoer^ z2+huWSLuV_RPQ2gRe(kv@=_RRfiq7V2iK?Z#N$f-nD->6Gc5($^XhCtAP2E`ZmJ|Z=$&WU)A+fF!Y5v0C$Ox?W?e&F(YZbPhso4<~%(Pnr?{NQ!F zl;58%BmKN&9h(2vjPSca(1#bvObELCzR2YILg6XL=pPDH#0-RA)4Ue^P#|M=o|+kg zA>L82aeQ*oZZ2?_Zrh~pz{7%9sLs^M&-^1pB$F1;ulX@Bc(tR98mdwDYU%!nKXYqR z><=OYjhJi5q9~(sT&2f^Ubri7MVG&N<=>;gHjh1W!4z`msQ%{Km>l^Uz`nSv66ltE z+IHly(crgig11VEuda4@Fu6WjQ?jrirtth}-x8?LSZbwAVrO*fvj&Lv9h=;6{U6M9 zOAvbnOqZWB0%VSVgp?b!c#Ao}{{ny~gQSJK~u=CyCd3zprO#-8P1H)MZitYxirQP=(Iq1xhB{B7mlKSDU6qZnpH-G@64L7N2>fIdV#pMw*M7K1|? zkmP>7i#*8;S@Rs{3KP=Kq_s=G>~|?;F~V*B#X)N7N=5~nS0Gr6=c33e@E7(!N2OniKt)uBc@)atGCtQ!_ zBgM`(KQP3YHQ<9tp>aSAG-h*aAor_b5(QKH<=}&U=<&{9Br0atx2;^+@9tu4EYDXHr zEE#Z5Y~J=lHjBrM_izk`*Lh{fB?SJ3PxlcnScRCEx7Hes-staqM+c98&_%^!;u8m6 zk~apXOqCd)Uxyx_F-FDGMH-$3>z{0|z};7GVd~7!J3h>VZJwpADM&b7;a?k){hhWZ znd3?n+K$iCS%ZzQh7L3;toXX*uDiCURr~SbAh+lwR_E6vACZn7LWZtQZ-~H21rP|h zDLS=8RJWHMS2qt%SNaa|pyJiURJdNv=P8(R&$^e8r2U)uYD)nsHRkTZ3u1%0p@-lO zbVxJ7$!o!*33L#GmY-V-7rJS+3#DVLv1o^n==*$rqg1HEF!1ET-Pk0ZZk)@)mr89k z;Rw_|tq^@M+eCLb-)QN|AU|H^$8NT+pi-npy2pJaJh2Pjl}6kIKt+$LX9f$JVcHXC(!a&_GP^CsAc8KP5vj9;Ku7`cEH*#_LY%%kHz z{Q2Ps?w#vV>+Xj1KeuJf%rl?yW*jw98@4%HQ_pvZ7kdDdUTd#%n{*_C&>83UYA5v$ zKe{v8Fq6%Bc}wGIkI0+dF?{8uNZZ*OG4jDmDRS&Xjm@a6v@V#c&wB7WYUXv@#NhO8 zO72}78yo#sGi45@N~z@?S*7QeK)VyQcvLq)e&fY?@{@(}Uh!Yq5c}0>EKb{5H2zC6 z?#g`iB^Zl^UG#yJJ8I8MQvc=UJ~ba-uv>3&aBqs;bzLELINO{lB7JN3FjC}Wh_Ti+ z&Ar1e&5xieHc#J?`?_bmlsI@tL>Akt^c+S^yRFj@aobwCdmop^Tzg!~$e^ZuEz^J5 z{o_JAd}!lvB>+ysFQg%t4nlF-abk0T`sDe(Sf}~xq7Tw?NF8c8%FqKMn#k5AIR*U* z)^77@f!LB*S~Fak#i9<=jcK!Dz|~~L-ZR=Sde&4a4qa!As}OP*hqZ1y;ILcm!$a!l zwHdgDW<1(wC!R60@}`t@lT|PB`)g^!RAM%Ff_m3m|uUFGMhDrCd^q10JSG}0QqS&|{%XuDMLdHj5&H{;U40*ex80r4fF8S>%g zSCyjRyr<$75U4`@b93?aim$%StmQr(X+v7eRrel;7kxrh9sU?X{OB{pO9S3~Q53eM zLXi~R??4q(?9eNJ{t;Vup=E++Fl-wA0vMRbEnmzgP_ zzPW=*nkouO7LpOw()uxQQ+HzndzQOJ`E{uwEDet}@x`H(%gzWf4sCywP=UMi0gSG$ zz!{7)_}E_P$r%eS8-)BSvF!0v2vVWsb)JG`D2 zPh@qcOBf`ukZYHAqU5#asp3^R3uN(dGP)W^=OuE#UU;``TDe~Vk8p_Xd+Y3VDsDn7 zIysJy_&3Xfmxqf1KsP66yXcOEp`VGl|8_C+t1NkwxeN@9EEeIg+}=k(?^aeFc6`|~ zyWYqrBAW~gp%)OY?oRl1c-QLtY7%qFTG2D5#l!@7ROGbvOv4k`97ex5g!Pa33}pI zg|c0O@_tcQr{xp__TF~&vWR$lv9M&+PQc0fe`D`0!>a1OzR^V^At0%wq9Og($(wD;FIyi_1&1p>kZqr9X@g2bgC}kz-^MH$Lb#(d4r@B!-l$5e>#>EaE7KcC z*nuOs9ob?*Cpe`9QQ|jbB#yHaR}F>zFOvY4wLE>AnDX+mQ%MP1(Ng+ug!RJLd{^)_Py=_Ph~o*r5u4dooPCB8NXm*<-079?2XoUJCXE3mfa+JhCH z=8v`ZR(1-SoGI^W2SOeX{O5(k%X{-A`6)U1pYWWRn7N<%ne|wpJQ~YJCnCt`(KuDj zBc+C^Ns<0PpJ*Y^u<7vMzAd|{Ywz*kTYHm!+4;DkhFr9=vQkg?oA)=Rns^o*d}SVr zs(Gs{8N8gFK5(_$i5uipRPI@w zGgIQ7s_`n=^Y&@lcgTH^bB<#GyKu!+U z+ZL7AKIY|x+SeZE_(uCplgQ|{nD4Eg(3jg!g{4#X=;-zeJ|oi@O;w86Bv9&=N& z!0cY=Vom?PcIQ$zIazKb#+2&k3ltUxlFXmqI-KI5>~o-%BI$cEWu#Eg9p!V1FYuzM zIqTuw#Ks3hxa^WJsZ&RX6vKdrs zZk{;VJ@iqdkgiP?j)j-FClfGjC)$1FF2KpMl#Rb#9qpgq!YZ16aKO>)nI2MDm^LA6 zmvnR*6Bd?pX#IDTMmfc(c(;P%kdKA(b*4Nk;+ZOXyKt}P?%AXWorE_HV*17=__fN_ z)E6hFAHL5m=cAryS=>vL_^{J_Bc3I`7Ip~e_*{}T$H&^I-g^f$Gy3q+eAPEg4(VG?mC!St3t0M{4BPTF(#CI;ILDVKeU7@s1gng<8C+! z*UUG>`P8o564yoTG_d3lmB^*{yfVulsw9CS`A^pA za^uzSOBjf^XXVcV6$M4b<{1p%9o~CXirV?=nyF!73dJSlUzk|s1NT9guCv1u6ws(k z98Gv5y}jD%l4!oNvO>duD1t^sLc7#bET{-gKI7T-KNS*?yIcS6eA@17C3IfuOBIT) z;`#E`U%KkW!h}-yyh(Ko>sYBBj2Ry@kL{ZWJN|gi40uy_uFByNT_Xsx4xtgw;DiL; znQBt+jT7U^AX8K~ojc){0(X>V$zkvjZYAI3`xSZlD^c`CFQ)h{nrJ4piWS^;kNHN+ zeDCU4mp>%`V0nM(U-CQsKfmQh7;2qbcA2#Ajs~#bX1h{b?V-JE|8!+_)l98W^bl*P zJ3KuuE_j!qjzU=}OowJ>XkegGHSz;9l#(b~AsPGLb7&J^=S{P@_bWJ6I5#p;O}C=c z`(Mv~D)BORH?FC8dxm=I3Io9cE%}2wy!@^^lK#4VdYKy6Ot0+rYSjhZ1-%Gc|mP}vZf?(2T zF9b&41WA3UM*BrmYCmbE$7Z0oP;xyP*}&2#qjm@QU?4h11=PKm%1SiICxA>U(~hIb zIu{>b^VznKpw}rUx^Jlp#!M$ZGZIXkQ=4cu6pg!QYMk&dpWER9#-%CHufD;&9;e2MrBwhsG;NLv{-Nph3BI zO|DXFr&h63*X6yHsjGm2r7tjo25KJOt?}L`5MSz@Ih}q6#gLVX!$>{Hhv`P zgzHMFSLaN8>M-5%?cwgM#|f3NtE-%rukXZOZ489OgXsn&z^raW{prl>+0n{PYERd- zOYcOVhrAam_7J!oOB-G&=G}aj+*|3ogQkUH#=i9$!c?&|JQxt+tz>;cVG%Eo3WvJj z+(aStx1{^i$vk-=8h=`IEq<6k_y-ACsipxH$Y;jwuekTMyq5R&_7G$nd%?&p?Id9n9FetG)r(;i& zkP{tg`Gcclk3E-<%3Vi_u90_5l$MooyPrO-@l9s4FV75hT;sMHeUx}IJox+V zQEJD2E-A@v5U?Ew`z*E0vqD0$mHwc=OTA2i@{w_UTZPe&yR{)2pxA+ldVG9W&7v6I-y=iy77D0$cD(= z@dl|C*1TweRaQU=RtWVZzEOPXTpW&JQ{HK+pTNUA{BrJ~aHKxhyUudmH{K5CE#B@_> zeXNes-zWV_pWD_D^<6)`(2rOrv-@j03R~)l!j_9xTf_27KA#^=aTj=`pW|O5Xy~0T zpIM$==anlk8h-*;Cg{TYq%#AInOs#fhDfD=K$AaMss8$5G5F%ZoSv8R#WA>x-eBT%nfNo2bmzku4fh& z*XR5M@AL4ek2HOgbSXOpXwLMn(f6r^#mBTPtv4}JP~#j?_ILL7^C*|Sq)4#=mcM$| zmzHTPrkL1(T56F*?fQr_9ZzsUMg3gwgC}CcqgiB3|KWVeR%t*|vzz-N z7Fq6^iz@s~1OFg{vm^AyiiwRq8~!>!EdB3I+%z9>|Bo%*{$F1)8fj}&J%0T5P2$gJ zf0#8(z`U_x@VVC4;P)h`(rjkM|C#icIR8 zj_wLX{8WS_Wr_of&aW@ySX^S(U>OY8z@j@{MATYNa)hQ8?s0KD3yYL&!0c{q@mZ!FnkC< zg}rdBZ}7rIbIXLBM^7zsD>_kj3Iu5_;`h{O(LS`iaedRTO5L)qXa5z^k8V2rE__FG zL8LnW!rK`$R|@E=--XO?YFPUx9%;P%kU#z8@HQ}^d<>-kEvwhlqkC#XC5 zEclr24$~2JVxp#+S)wWQtIhe@;_!wN^R?SG%JOgW-a(sDoe9#yqZjiFyiAQ_?Y3Su zUk_{a+@GszavvNVEO!2LLpS+qbGALls9s`IZzsj~FUn2TzcsT^`^y8;ds{=4i*2}8 zw$Co{yt8&JuW0;X;uoJh>JE=PjarETp!LA$oi+SE{HjJ9^-RShxbe8Z`*#zrX|={j zuaVMUeF^H$r|g&aQm*MYwoPE~E-C5!$UO*ldB$yd$OK>gsl-l9H#xcCcMjWPUs}ka ze>z*f3rK!~cvA{tFS@wG(m5g-UEKFyj;#+4!nc4a9zr(91rDV(_=?qX8jhwpVOczFc^K4dn$BmPO0IbqlW?09G5w71F4Pm?;je`&|@L!47=J(vY zgPVqjGVeh7a@qc+VK*pK$#zJ?J9x6*aUDTWy+-}NvxI~f&na6LTReMQKEUm3$RE&Q3kDBm8 z;*B31oVo!#Mwxbawq>Asd3}JGSvI;E^3p8&F=MY0`1<-Lj%e`k@Nm1VVy{n?5W7dO zbK6W|{^sSE0^pq@vK9<-aLYjbdCMDf<2wd*q5udwH!bKqLCa6P-r^1-Ky=}BsWiNky&!+!%QT9MKR3AlHPdPNf(|ww-t181(j9}s zly0%fz1ol|BN1k}RVisX5FYVFeUGu9)u=q#QFU=T8CJIavm?T$*$>VQC?6?8zL*=6 z!i`yaF0L!XyX2<_V6}ek>g;>~tpK33u}_gI!#!14x7Lpb@&Z~O8ZGrBE-fqNc9?yf z=nb-KC^s{1b4Jqjx)VRrZVq-b>Jyr}k)kKRLdr(J+u+^1=p!*1=Eiv~cA!EdMg-4= zmVC)?s|Kkno4*RdHoHhOvKso0VNiY%vsrucvy{iJ$!h&i89h|90sjWmQNGf4aD7XB z+WAXuHUisBgt{hwegC2g#FQEv4>!XO&!ZH4H%OV4-V66(UKh^vnvN6HkKf!pc*xC# z?w|F~r>~W*{cb#c*wsJ!^J-F(p@5vsLsjWJ9VZlwde-_>Z+|KDRhaaV&zV_R5&?Cz zvf9ch?tcHAXD7~ur#qd8K6$tZ1ECXgqT(2@jIemo>;2Jzgapl8Z2Tv2nS8J+E*JJV zvhd2>R!gehczNcHR?}acNvluh+};eHl#09^ZE9S>aNxrby&fRW00<@Tx!=Jx03eKicb62lkO%QNlD|eW@ z$1y@-dAOSjgGk`@X_S=jIIbY7kJxEF`Mw_o+D(8cHEm6b z?(gp>U3-O#llg4K;dKHjNww^Op#Q>cgz46;t1;zN=#;Xs=64@sJ}QbR|J>`6_C1rR zOYgpfzq5PhLn|XI+E8KSW7?h;6nql=K~h*hDMvRAmTl!ODO*zw-dF5P`3+76(fSBE zNX!|s%I3iqo7#Y3pOl4+;}{T=b+?P_~6}+WsEa=*5$65uPe{aNOdhx zT&%7YigcoT%&8OdDTMeH@M$`sO=6IcTZJLuxSyIz7}h&qMMx==nbI!mR6Wq+(T3l) z(IY6NbFb&*H*9b)j&%?d?Q}5b2GHAdbMNQq4=-k5q&JxMFgksA9?tr$+ya=WQJF1;}5&+sn4sb&N) zIQc&M(@x)c9P2X}7PpmCue)>?hmhS?PK-etG##wzDJidv#w(gW8K^3y@hjd2A8~AK zOt;*czUc--+gl!X7Izi7x(Vk%CTQ?h*p6WR&0bRkVS7$K)1dM+mT3M$7Iu;KH5o`}V#V<@?rt%y)O;-@d;}^+$5D zfzTbU+p;UVNdo&#KcWpR+8%<0-<^jDK>P{q&DZWorp@BX%MS=;r-$fxLPs zN)=T7pJ8BgEh*TI3Sns;__a#L3^@c6}dnSUME}QMaT?$%`AHo!H{fIS)qLKwEcd0pb z^41eBX;3f}Ju|~8de`16apA)1L%Mo$#*{TO$f|Q$xXKt+<9&1k?F`~ZuT#bs0Rs&R ziGkUd^-SF7SBZ%7=CillKptuskUx*3X7YI}hffBs-jz=|x%i`P6-I;%MO#DYk*rUDR}`h9nZaZ7A+NafDU>xIb%({nK`nz(-+xkD zbaPLOJa{2IEaD1Q&8?_6Z{9S`FLcI9KOXeY^S$ucAV7OhF)=avBfh(fd;mqbiVRdO zgnp%*N|P|kesb3Gg>)^8GWOLPKKiCYpD58T4t`LM+DFW}CyaLsM!;tL-Jxg1u2t*p z7H0>uxpSXre$PgKhh^-CT%qLb$=*vD!(X)w%)B0EGBYEH$#CJPQl%_5Ewqdumhx?2 zKphFdC8qSRVzYSuJeazQ-zqYuvoJm7rj+wr8MD-+w%a++Z3Fj8FWec5jO4tgfxGcBmlWs=M>Y54)+p^`l<5nskL?XH$x-ByL1<8sz9J;2byV+fZ9** zI)w+lY^TayJL9#4S%wCG%YaC&R#&rbZG?q|g}!%699&OD8v4|41hGrL#HHEt9248M zF(aHN8IR+>9vN3WX;px3l~&7GrBwTPT?#?S-*J8{C?IAwC&ymYvq;uCoUH!l+oc!G z>w0JvmBuABUALfJo`3rFF090Ki{mgbo^^^A&3hK@{EsvQ8;`9=KH-Dw65awQpb};;)J0yhXp%&4Hv)?4Z8^+43F`(TsztP zVt}+hmJ`BpXhi<)4qOu=d>4p$I*{eVSypIg`x? zDDW(9d*xkwizboR*~!hz^K!%wTh1I4=PiC4zJaq-K8O8iR5PevuPJAVIm`NHcwFt_ z?AGg@uBkt@G@*wC7=2oI8vEgg1=~20x4$30bcyJ-mm(f=ZNawJd;j09to?FZ!_W-_ zzqqS5Zk>*Fw?KIO)lJ@^)yDp;*8@P{BcBuP=T#H)Ii)cpx$_+z9csGbptLZw$;{t< zMveR-Oa>_9xcb+rrpJqrrKp0HRnVhNaZ&dS_s@J0igm&KK!Tdnv~eQQxoW-3$|;qO zj8cMV+b{c=ga#X}-itWn-%CZk5kZ%N#_4@oT535_pJQx<%ltqTkt87z00-_hYkGWq zNF&zq+;cHB<#1CD_u`z6PI6@!NGouY* zq+#IaRg;RXF!W3d6#e+$Rl}2|`)FMk?684axdZFlcax8o$36)_d1_fTq-}fDD)TVY z1J&!?j!{67u&hV zr&ZW}RNvQ-8AAr$fyX4$)BHa?g?k+*|M3*sTEaE5v0$gtdxk$Sf&FtYpG@7yFlQ2} z90!bC#i@{*RGCp`W00>`MknBGr+tqSNF<=tnf{lxq}Wo6bW5LQH0gHozC39-*}HUe z@luWgF}?V;XgXTz!z9>7rm?WrKf|S+fB8=$=iRP+xObQluIN%F<7Iq*-w7?d&&iB@4&LPQ*d4@L{hbNReb;6-|HGtZ9wt{E9(~xS;kErdac`Qr zT}9H@Km>&T2Xy!t4K?;+>tiT9qCA~;fen3W>G^tZ;f2oGmOqiYB~LNX+~a%a2q4W9 zvDnPoaepk|=V>X{&*P?iSRk3hmmf`g{&8+J+2@7L#u!!JumjcZmVy>v zPDoeXxh|K-Z8RaOuT>3S`hi-Fq^!?x?tgiQX_Q~ z6$yPVGqf>+$kAA%r<=GIO#_K?`IkRJb?@k^*51WOqm_`zSl z$Y3^)hr@K!V5fdT%Qj~^%MCU*Og@{WJCp3Pdh{;4-uK@rXu(E0^X;*(x7S&flzY^N zGLrQ76(S=R$}*anILUHZx#gXDg}o&LnV2?rm_>aKDE8S4{v+nEp{0AaIgG)nT4^yt z#rGUVo(3G?z8(k(k5d z@2}%?J=00OB{;8;lhV^)0+`V%W~BM%s?VBU=92|*BfL%z$&7E_Z0H=5U5vh}1<%yV zJ)UnY23Am>cI}{~oFX}8mkZr3woRIuTd*OptKw3`-5YFlp!7dQ%jGzY2a63O6%K1fH?6MlOZYVsa7|92c98}L80l}BITjiQ`l z5EdI*0?+F(f>Y4eU>}tVGAtQq-3C`C$pI6 zxuyP7o&BQBSu$%dCeI^$*_XTWVHSP6`VDVJLsywB`t)1j3h_@)ILn`SDdT$OY7P!O z0S+={;zh44o}iJdRU%q(dOH*K=~Gr?-xap9%9*g7zpTR-PmT&$Et+lF%2}G`#xav?DX(D6J+-4R`8WG#3rk86(8bws;a<$; z(N@*9cJ1|2o5t4&+T4Ctx)5(PAEX%Xx(#hrCoQ0!YT5;Od*>XRh z&^xi;I)3qf6#4HXanzr;do)9b|F8cP?m!&;?-ybJ|EvGq6#pMvB1aO0F;Z&k(2*h& z|3Prt^QSuf0S90)oJ*DxqW4 z7EH+SIY0GENJud4e9y|b_v7eKy}!i7l2iXA8%;9k=R`_N8<~@n^P&q*Wbz8(*qm^TIJ!uSLPO(09WNlS~#$Y8P^E4>7_w!Oa}{O%ojU{H{RLVCnt zuEv+`?OExGY7Y{)*UA@1_si|ZL*QmHIxb_QHpZ!$WupR^@E=Z$s|K0hR9|WNuVEj4 zjn@j%(a{0Hto|(fo2{%YE}WB%1=6YAHC|p`aqkPHp{IxV`1m+1I{NJ92ATSbCY=O{ z$NK0?EWFDIRBppC#SiyB!vA+PvZHim4Gd`A-Q8)#ym?R0&fu)Cm!qS)ykNRK+THLY zf*5G|VB&X4QBiSpeEbufe?}wzLXh8KD(clMY^5~OtC9y>cD*S=|9zJf#Aua7Nm3F4 z$cRBrP5qwJgaJ&`wrE-b$AwPH<&~B2qmzGiorg1xSfbv0_!SP*xJXM&%gpJ~E&x3| zJk8w453^sub8Sx55=ux&fYtCxyFlM+qLTX>zin6q<=xKxHHb+#luUYKtjx^P64%7U z#M8&88RT3s@4XIG9l@Bfv7VkDbi)HRHPQzU9+=wN62ZOjTK~pJD<;m)_h0z>nps(e zLD(~~vR)FfpO8HIx3KH$hC1F~d!(%V@ZD|YjJ&+N9^!DurJtYQ_U^89mGi>P@-jZ$ zV%eY#%vNY}q2Nzg>$Ai%@Yo7L5uEYuKJ`#(27PoGRd8T<5krTo9Qo;P3?c zl}5<<^WZsF1H7DBTZ_!oDS7~Xzp^W793CEC>9$6NiG{WEZ$VK|NT~78pQ~loLzu{` zSFf1F#Tj~gd(qMk+Ce$`XMr@pXA$J$be9CuhI*B=g`lwT*Zq@@*q1Lcw2O?RAqSYl zFw-r!y;^s&DKa`bdWSvz?b{pT=riCNhKH%(dd21CAs{+DF*he;V7Lr10lA4{Rh(d~ z?PP{VUb7P)-nysv!+j$Ok+d4lPwd`-yhn(-YC>DY{$!CxrEtR z*T&9#`-9NS^zVmJ|NdQpz{NoVLqkKE@B_D@HFjX_H|o9d$>wDBlPZ|M-I1cGpm4?e zU_G~V5)uO8)>N&A_eN!<;B2x(ox1+T;m-VnM~}i_A5Rw-4LC51Lrlzwi<_Ga!LN4T zc&y($G^B+vb*G6LY)<@J=)ysoCaPS9>~)NKlK99G{8$FDAj@vKaY()PWf-mYC>mUP2_q(JWVFiV%OH0SoSOOoW-h=I_Ui~K`u3$rU-@60 z={%=rVPPpdzDNH(Yot4NZ|m9b9Cd;|lWKQmE*>6@Q5JCeW*dX>njlEwHsEhUH^$)S zpST2Nuf(Ys!Uj7snhG3-m88%#fr*J%EqhZgX=rF19v?q}XRh?W5cl3&e);`b-vji# z?EVS|6JD!+1Oo$OwDdW7N=gdXlNNj=prD|@bF%9L=O?(!*N<+(H#aX2d{Qqm7DEqH zXaPa*y*nY7=_4Swo&M|M>>rNzrF&C1HwFoih)$3+?yin(a|p2$hDW%9yna38IFiFd-g2!^XHqVTmBc#kPAb4k_DJ~c`4*#=%ipA zJT0UpB@KYQ@pG^GVI0h>39%U zth=I2e5MTu_f|0^c-Wu7)WRY-A%TWY`~?k*YMM+kzun)lJj1GQ3ICozvJ8ujZEp%D zG<*8=)#^~5fALADU0c-EO2=8ua@#+FP{lh$2~Q8k9p-kUpbw#S20qJ^|7U8UxH!71 z($Ue)EG}vM!@S$)9hY8Mfd{CH{Nm;Y?D~$>Nfr3JWshJr> z=KA{jU+qzXKqkm39r>5z3uh>?rJ$}*#!7holPTPUu4&!%(iyC*stQLP@2zN5{eYTD zKECnmOKf))WcG+9TU%R-!a@QniNYFgSocHSIq1Mtybj;xj zEcV}I(L?3xwQFDG*gY_rb|B=pV zcd~%$YzSvPL_N!M2dE_IAq7w92so`uv`O+{^{GW(XmoTmx9hSRqE;E`u{K4~d;0B+je*#WCp*cBhf;an9Ob+NM(QBhH$1A>T%=3E4s$8kMNC@|iVZ zXT58I0gZ()hB>6!< z8*A&P=4LF!@p$if$@!n}&oG|WJAYku(}r~CLpn>gi9 z@G&VVDRGu~ytA|8yd+*hKNp_kd1tk+85VJ1B9ApRBo528#7Ii5hX?}JcLUPX86gIg zl$8DsQZMZ9_DT3x#?-OElnbN*+5;Nx*I+5NVm!Ng;z6dvbRASYy_e&sqjru_H*c{1 zhlnY#0$9jeV$&%knyEMXIJ9%HVK7ne&nW5MP$nCFO~F4ZI=a==N=D{Yx7+GS(S`h2 zS4Rh_PSqo(ntidVhZTI?ZSO zn>46#(5I)fGcY}U+6QGoUtV7B=;lUwLV!T#C*LJlA~y85g7r5re@adtRtLs@|IQPz zUf)%-1ep6k9X7bn(opwj7|zonL6`vNZq;A-_1jSIc}WU~xUK$Xowm^_G}JrozGpo! zFn}Pyfsh6`J{x(u)KZ}FMWt?^a)G?Juif^)qC$BXfWAU*O|tufHms&WB%#=ijR2jg zI-EW#SpKOjJn;6d%54dEA3&O9u~%tnmkd3K20jj^pE2REf8-b>W?3>(JEY`ATP+37^w ziJ+UgczI<;3XP`M##TG3Ss4%~ocG|69bUkOA+ZH_BuMu6_t7x$@X+&BNJt&(nU6h= zFN5dt9xKw&*{M*lHu6=Au&%CdYoHm0nrkI=aEqONl>sL^C_$4&>qM-Qfyg-RK zl{uJILFU~!J6cMgvn~rKXZd>O@la+@2p!W@9QrdIZ5S~`><9u4C~0YFS=T2n;4sDRSe>2;6AA-Zuc3Rh zfTZ@TG=RQtdQ_jp#a{zI$Zd0ZkyW`#R({I|pR4>h4J%(eoB^@O6_tQf)scq5EXCHL zYA6z5tcDos`S|#hjqe{E#E(-v9NU7U>X!^Id|8IUXP?V!2XaDzH3_i6r)cd8Ssd@f zGcQPh50D0+x5dCaFo$YAcWwZo{IYE*m5yg?X9r$ZZ;F(+_m9qRQ2wB(a`x{_l!k@j zb2IvG3itux#d!Lq_H}luS{4`Wd*2BDurLJIGdIEO{8VUeZtitloYrq6bwyWKUPyUC zpX6ex(EL3DXXPfK&=TON+WPhYR%rOFDU$f@*eYg2GXs{}tsQ0>ULpZryNkJoE8b9C zPGMBCodBI`Qz>56S+bQ`0PgE7!E>0 zLc;xMqR>dhu-01(4%#7M)V9H*b3-%B2`L7oEtc^@s?b0evT8_uz5mC5*RPP(#%7+2 zrU*EsRW@2e=k(|4aEhWajv>ro4GawI2gvGFwgq4Tl;Kv?rZmigDz+K9hy0<9h+z=Z z-&7oi90HxYbaK!Y3g=rvHQ51rpCaUf3qc(e7e@(C|MAP0(E9UZ#mz+XpKouWqr}mX zYplY76IHDwFW>k(R|75{@c^}^QTm3q{ZZy)g*$D+uFFgaX)lz;-$qC5kNne*&VYV= z={R1VnSo#2U}hdFeD~(f$6;UPWPW96JI2!b`L<~Cq6+|*B;Izg?#FFoj1MCKtxqJ* z7Z_AfA^5=a71nOwyorxs6A}_ux%?lHNU#RL2p^(pZs`+LdM>*$UNZ>zK;UeiFJMOC z{e4Kcw-C~7mDI1>KR{$6MexVV?M|3B4rscE%;6ht6N)~9??zR7uZtn_IK%b-=7Hz3gVM6z`{ycdg{j{ZTB}MRMmg0vefGr=v7|vEs7T5L3(=K?>9z#EnA%&&+RdXfC2qZA+*Y;6w zJ%T7K+9~7IOEI;Q9!Tl2;kdZCpvfMPj8V}1Vi+avxlDmD1_b2F?vjGM0E2q@s+PC- z7VfQJJJn9mOT@*+b*tSusMvJ_K=)ans3Mm<2(g4VY7HT#h4n}|-~17^x))wbslt>1 z$zjffmi}b~Eoz+Zv@=@4c4?MXxT{9OV&moV@rAdpU(R|Jls~}wcDSqObdvg z2J}EEhwmqn6F^CWzGQ*&_9}_brqgK2HysaDoLA6PPk=2~x~^OSt`G{`6{ACih1+RP zCW?j|6LJw&#vzJHOuL_(&BF^|rAse}n4sQR4Sv1^7$PE`!{94jb!tiq>6I&iXcz+E zI-tJ3KEdP;3+t>^&DfJc8gNsk14`os)0pX}<@OF-evR-am0MqnL4gHYJ_euryvtIk6>WqkR9 zniYmkgB30XqPyc&Xy~WMSqjCarKe0qp?=*5=;BXBg$tts(k`4mb3qCs+e?sF;CfAi zgWnoVD%)CHXF>PKQU8RSUr->SR}YhGms{J~ZU8j~Jd~rEd&T;9&Pza_;gKfd0Q*7@)PV7+JYcVh`fxHW`Mx0TR4<#nh1i8i;FIM1YL(mN7n@fX#jx~mz4!Wk%LzMn48Nw z^&umcAI z)OhY-Dkt$uL1I;0mnMa4|K~r-kC~msA1^-?^3-_zxJC2E8%_n1Xd(09-gH}mq|82C77bJg;>ica=-U*XiP;oUs>s|t*SW#Tsanoxx?7`~ zCMpksf*k-{&w`@9q-5Z)u`--s?FzLid39uRG8MVA+5YsAp<&8uoveOAJnVpVi!ePg zk@Xa{HSO@;7(|Kf*zj<&6Jb`P&CN~AF0)EJ2P31@CkX)5$!_!U4RuIdzy6Yro}SXq z(be^HN(!BRwR^#U-y>)z>;Q;{;AHK*TL=XsBWIof=*l;4lu&2tCtJ{jnEm%3#g$^8 zZ7OPNb2J20Rt^pg#RJ($O-CoRqvH|S!3}+VB;cF`S5ydm`~IEV@6;2`I{*ox^DX?# z?9x&wvs_GvxIu1CP5^uw^j8D{)(=&M&JR^RjovE4pu)Ztl8F;My4Lu=-QArDm%i{{ z!Ys;3*nnIh^UZ}a`ZZKc`$axW_tj2`noI%EXMV8$2e>*OVhTLdbF*rxzt+bCEx#SE z<`unu{Tgiww6wJ~_4M?_@ZDYVzd-qO`=7r+l)t#193N)_!hEEtD6OPK1PzOmsuG%$ zb4Mv%oF0}Q_A}W|{-7+XJ-7ygYimElBTj#*!ok7^6I=v{QPcQ1t(%)0umA)( zMfqO{L_|iG0J8!7ibo>0_u8oQ`bHLhgVh?Wox%qVPww#SsL#KRr%yB}B7)#_yO~7L zb(!=(8+UHlX{7W=68Ctc;D%>YB@4Pl#Ps+o*_$`~MuQtTJHFhA=KzM1J2Nf(SE7l7^19uD9Vs?##6%>YJjryFa9>Z*UGe(cK}JN&-+iYKkMRZR#qmh} z%=&u1)h|$z)YQ}vRaI3Gb?>BQh=G$=vSjO`prEil8_I}@)E%$hZ5^rteMSNx9}c)D zaQ|ULMF)!x{wK5ulc=a+elOo=;QpTAfstLiMv(D*zu-EQG1xG)6=Zg_a`xh6qXOWN z7-9+}e0jJadD70^%?)jX;S$jV07{&R^*?0>dV-CBJ=O~7!D2QmuKHKDax>Ip3KUsVcIvv&%eqfQdj(P#ucN%)9{} zFcxx|Uij7VvG*HU$os_()8a_OxI=w3wXr9L!_IL%j}^fg-18iwDYtko3#iu)YY#~ynXu?K>)4rC8JefoHGJFj=yIVvAykxDCe1FX?-tf1j&A&%tj}&N#SHt7;3Lcy`N<>*d+R4 zQ}Nq}^Rcv5lfJ0Qu&Ah(zdzrVxUM{wk(GS{e%aR(Zl3?^NzlD%<0+VQeelagHuhlk zRvj$@^TL=W3Jt3;CF?MCqfa)gNv=iXYH4XTDIbqRXQ8HiivI92_^fj?3&`JecX4U{ zr7B^^Pn_{Q?-g|P-^Ilz0M{eT82nG2^023%*AV3!M{pk;F<(jv%7K2)Ih6SfGd!|* zv6}mZonZQlDr4T@7O;oRYCtz{-@W4oZ39a9Trjn91Kd$0l-^{+hu#03pWqRCR*Er; zwDb6AWj_t*#$|(%h~|S`ob`7$eQh4u+ZQTL;4mGi_4M}2`uf)XiTfOq_w?3HCof>- zx>e3ue2ep~)Fv%3{3*{EJZCWS2H%z!J32b5QATsjLViV$xYi3aeT6y%CeNQa-i=_* z5xH1~hmceBoQ53hEfu-kE<%ZEL%<@jovxz?xuqREzMYvMk_g|}tN*j+wOtY3W2jjy zb}!2*U66(X{HIvrZ13HhWZ6$IBO?hB;D1=R=hybV*T>4bI()%OKmwXWNj^ahe|inQ z?NPZM?ds|(!W)NrBgdd#Lxf5zdnFZwi-ACESs19^Q8zM*hG*$I;K5DO=B?iwRD=O+ zJznj>3A6WgOde`$$7>ywqO}RU{-@8N(Er&6omM-`L1kuZFpCiFah{>3u@LvoiRW6! z|1&~}c9W?*3t9m+O8zo4gSx6k>@nMfQ?uJE65$6V7YP7qcpmI~E#-Oei9}S=#qIuz zgRlsY@CKkGHwmK*`S&Tj&#wS12?pyWC&Y#oLWmP8A(F#r$-?ciX~=F^*}tq-^7=LD z--U%n5M(2%IosZ|YLHS;U?ZL@Ss&C43~W5A3(%hW9BG}7i7zDpg0E)NUQ6^8)YJxX z=SCir7V4E`%(95pm0!0?(*6TD#&f=YZ2})txwrj4_@gA{#}DDpU%t%4*psxj{yK1@ zXpxi*@{$i|nfcM|n%Tw0{i~=We>uOq3#R&bRG!3ub6UUz#NfSDn(lutwC7#R-VG0o z#uK^jUtt(Bi;B|0*Z}6t2ub$lfp@LT(q$kjvGu15oVqohC1K{3e3GDdY(pN#Lg0yB zL81%{d+G&@Vh&m$v)-&&gqp5oPc_qa=5%qx6d-MsP)$`S>Wot8>pAwIgQq-6I1l>vd96 z#8SFHLvV2L{fnfeYtY)HP>;VU`k{GLn(u+0o;cd)=>L6wo_cV{k|ulp*XU>%cq*Cx z7w2==1Ra?`-$YAdHOH%YRKAV<>Yf{ABmMRM{->wAeH(TJzhd5h_%IJQ#Ip&a_USB% zzhu4N8Tt+HBvK&Z=~ViT8>i@t^fsrb8zz5C6_j~}DoutRyf8Bu;VQUeX(`R^yo zK!JgP=>BgFG@65N{WxUJ{o>FuU+=rD6UGPQqIPhScjh8F%Yzn_Cu+RNW5thiTAyvv zWb(!`fHWr?%Mh(&?1ckbMa3$UyFqalUAI5}p2Pps)|Wt2)&6a7QsF6MN~9j52$4|9 zRHl%5Dnc2`ETIrmNyf}7lCc!Y5C@At0twpxp2opbg+ z`@Z-6yQbg0-{EZx1ka%h%BAfeKClq24#n9s0FLxT#Mb?)Qt#jvF%VQL^-b@rkKEez z>)8ofMXrb-Kw2Sj@hDuwz(rKc+I|fY%NM5>G)FmC6rkIHdvVjGuKjB>XWihp}Oix@uc zox?GQ6B3c@6d_4^08kj`udEi3!?JF*gEkMVD@`hZz~IHjIcHv%?CAJ-71CXE%RX~! z+o7X`-S-cVLs#HAei<3D``{)gD?5VF5?}i2)VRYo1^;_w`l{8<%`FYWZ}+ov%u>?Q zt)Rp71bBsTz0R>?QI(mT7E={JnV?T-O1WyjAN8OMU}IxD9waO%_&1QT%3+MaD9A$L zG&DN80IW-fZ8(7p&Jk$5f#v*Oru&abY!d%IHYR%O+%a$OL+~dWANhhnA%pgqLP_Yz zv(zE%mVeLk^P%jLF~p zS1&<1A3BtnlA`wteVXydjR8jK{DXp74eQiP(dt$-w# zR#LZ}&dRSW1?*JKgh}GHLjBg$v+W^|*XM`>q+Sa4(>W4;?LfaTI~Ls6YCHS_fj7#N zMQk`9X3V@+SYag#J_Z3QLiiFH4H+NbkUE+=4wZ^h9CGn9mQ_eJ>pf(6wC~Gy^$m+p zG(c(HcX`eU8=TrU=G(RK?aMOcJ8)nN#8*YE^pac~ksgGe?K1!kZKS(c7=PI#EZpCQGUkxEju}yguESUz-cwY zcCwAT!e~_Dx%=Mh07O%)p6ko?xi`zDS4w3GERRC^6BHJ%0pV?dLO~t^d5OOuY|})j z?wmUzX>zgBA`++ z+TNkq;QCGf_)(3yeI5vpcbZbeBR@dZrCBe4c}LF&@>4cZU|>e3b&Il7G1)(9b$v3WX0QvMUqBw&c!)zAK(oW zmPPVRPM+L)>eMN6kjP;-f9`R_#bvMm@+`KHm!N|J!a^)YR3i6r4AP^91|#Ca(<2#? zR$m&TIb}Qq0g8gYe|NYiowl8#8iFjaTtfMpLCL!tL6TYiUAwGW(p9{cXFEx**q$$%EOxS|9?Tuv z-CWDg-a}|LM?i!k2$JLg=_WR{v5{*Ph^$p>rh0H#m=g3+aHw>doAzNBiZT@>s%W|| z(ag9;LOh)DzIOZe7J*&6^qrk^r<(dCkoZPZEYXW$0jE3 z=@mF&psc3GP9)0vJwG2yNec-NkDJ&(33IB)R1o3^CZo~3#3GvR z7>GkMGWy!uOuz&Q1aJY zP{CDHRrl0;UbqnBoYj60HNHSDzE#TeyW!h^{xJ^kY4s5k65>T-FJt=1kt2rY=D7~S zHAQsxVe4v~*Cg$nP^z`a4CzYM`n9P3fY?~Q8+|7tkv*AU%E$=EuO=>cD;m*{%gK! zHJNF<>@o7SPGZl%q;l@Zb}&{1t$zr8Kp!S3kXj4L9rA^ELKy`AsvrWJ5aCHf-7|n> zc_Sz&RAMhQF2N~`TbJe_3XL6rPLcP|vx)nxgMtrZOlZf&rvJIKPbm|jcx}X20X>j%0;KkQ2A^qp|2Srq>Q#B?ANLJlxxvTykCtn?pS(kyUr) z#WPII>!&EcUAfdj80Yr|0Bq}{?XIon=1N>Kr*vYxUs9gv*rf{L=TE+SBG7F~jikB(qtdGMYc3NM-(&p2MOk&*XRGYbJP z#Qw<~__M{Xaff2lL>bRLJYdC@S~ld0^qG^f0@yJwIav$PEbjh&Z4~%ut!ZErTQ{>k ze{(TbXxt3udg?;Yz!2^Cm5-I3eP^{(Uq#o@57V8~2)r;JM*5G#Bx0Mamo8<`F|C2M zio(WE#6z$A8qphcEln(X@xtO`cDsX+uyCWGi@p707SnjT)`Y$teCZhdumxP~{h`3T zJf<&v%bxfK(dNmGBO;lpPE(SWcGdXF3EpZj`y`PZUc}N1%S_0xI}Mhr&%qWF?9x{u zr0kE}PlmK`v3G=OhM-`7sawqm%Vs4f)~!wxIx(euAV+2Z9I?o!d z^$y%mFBaW=IR+)1Q%9qxtwe7p^*{}9LF=xBis#ZZ z+YK_l3;Ds^n>RnSvve=$wV-v?(u&}`iPE@zHGd9GzX>A{*z`?9vHE|Ys2W04a zbqAbtsfM^@4b9svl}|`m%3SyoJmeeFFej~QlcLliqVoGy7PTTv^zE5Q**mv3_P0o< zXJ>1;y(eqUosMeV0Vz-F@84WnL(Uvt#ipW0#>QNA-T+SwcBYkQzCWu%RSWeH@q3l8 zIxzd|7e671IU0M>7dQXewS5Q@KB~VRC|6?;KR@1Ot-mcE*o0l|<{Xu%2An|`*Sp{tLZ^WZ1eIUK6F z2vb8LSQbuu23Hk@VoZ74$zrCw&TKK3|0$k1i^{W;kc3h=h(zMf?@hgoV^F6s0mw?V3xO<@H;~I5*KTxBsxlb`1Nq6qpZBk>Cmg3sT!NKwPq4|juj}8VURdZcE zATQsSnaqYTV**2RP?5(52Lx-azavMi)oowu8CgkZp?=Tv;>CyD2V5T2H@|tKx5?`- zM#jKJ08>L!yz65A7UcY-c-R0*0>=W*?7#;<1x6c%lz98MJBL4VUfVY`WW}mxD)nje z*L;)h$6wbk*|Ub5mrlJ{1ac~wlLp=nqB7UPxCKs{o}PY*UYUWZsYry7NJ$cwP1Tzk zanowc)!!y|l3Y9IfgWiH5KLWt{cDg}>lzVzf>eu7IOskcZtk4(^l3dJYXm_GK{A&n zpKy_Wld;?hm7&~{KK%9b zK|<$v&@xnxO`n{LEoDxvCE(f__9*JZZpu__xmuAa>t3MlwI^bAf za)ptBL62t*_!mL^x$)p|l4k_`^iU_orM3R@ROJ>JYl-Bb`+;a+&kRdT&eNj!)E za@a0aEhamZ(yUo0*j;jI%7O^c&Z@~&k2U-ZqQ4c{9^GG#xzsq1U^4*0uKW(GH@=Va z^L0zRo<$UMw5DYfwP|+Fdbe(!FbQ8CZS7n4M*(m9eNW!cbXKyz19X{2)ddrFvzQH{ z&6~60lad@e4gAsl#+*88`zgjC3H()*PF(;b%V;?>#2=XQJOYO7qlwl?a-6L>p#&^V z3_xfX=bpQ|emi-7i0a!+Qo6;8X;Rf+=Ep59`Nm$m;7$fOAY4`>P07p4O9EW<{j9)) z3>Ak!UHy=>Gtxb%*`JeCh?^@2~xxUw;g8mQb74Czey%l~;9q!S0-S!MV zl@{+9MC_S?-6b!?PTnjjDd}3!k*D$l8nFCzwIf+aZl)GUyY>z8$7ZssX4f85^}Q z0zk*_HgOZb=O!Har@I}h3|B$AVt4CM8kiKzl22A+4u?QvTbgjaqJ-QB87a3YMH3iP zs`@Ff-=kTg8hIdVP=H7>Np^)10kN^MU0x1+ApcrzZU(|}4=#6sJr162Uir1yziJ`1 z^yyOnMglEQ3Fl+G10H)l#0XuU8xH}3LQKXQ&0D0;2#^_!F#{{hRFU2nUTJUNpCEpU z!L9*#hlYTK=mk4Xw6HF5$7_+g#2GAZll$3(HWpmm5TmSUa4n5jLqVqG*9s?NRKka< z7{xt!#8%`zh*tHVBNg^8nx;v8GXTu_8_Fdh725k;Os;l*@Kt}_9Q0G*GNpjR2TmK% zJ4_K&Q+py@zmos5k-RCwSzEuX0Bwbe-G8~NQ~iL0G3Z#z!d9z=n0%``#`fEt9jb{- z5M4Hb{hz&o=1OHF1x;fnsdJD#NbXNY1R8cm-U&}hO3E4OYhBz~c^u)rJ4=T8Ltx}Z z<;$&M6s&^MF1>ag>n`UbH|x|@xVEdS>tnIwZ5FQw6mgf{gHKMCwSM=&N*Idm%*f!t z*&FMMIXqgk4GdChG|V(P{TZE_S8NCKE8UiKptTXI( z@$S%v#or7d&wea%-GL|)881yW*xhq!PV#`I{@a@S&QEfqDAp3aJ{kX+`4#KM{&?o9 z$HbU}F84{P{V@0~VmfM@3#MhI(1X1Il8Vp#9B#5Yy!LV{RDMD$Wrs$}^F%&r>5ivV z%@mK7G0Rm6?u(FSLMKHl6KB5eJi9H{dRexq(79yrQ@_?B{zAlzF_@iE3uJL{KqYK z&vYjx%sDeiwYY!GKSP|>GmDIfC|=xu{3zFYM4aCR8F8=jqYz}uHKTi1gyN|C8Pa6v zWtN@`S2E?aMa$NsvhKeS>u=LrV12V|`D^^7rDJmz3gf!e$Nm8OHg2-1_F^ZBXSz+t zyXw2I;L_Sb;8DcCf5F#vmmm7|~o;n@BEwDlp}v z%uHpeUbyF?-_OGxt~Qt9cusxO-4v~z$Hc(E>H;Xu*M$4Oy%Ze-TLz&SZ<;8XTlvu0 z84KHzCN&^{Xpb$)$vUba^hS z=57z<)EP@4K!LaM6GirIyV{EoII9bpK=b70$)MetuHZdV{@DBhB14@)4j`ox(a_Qo z4G3?)sidGF8nQG4A$qbPjb@-TAwPzdA1)!@%E6AGD5kZ3uf4qmBq2eM}&1Ij#>Kmb>XKieP3;3aoORdJm zZDQTKZ(j>v5)!cV#HfeNvom5A6@smi!NIy4&p9kPj&dq1OQIEmmNYlAf0C($L^OXj z9Nx1>W1&j|8^%Eug&dcJ2tFRJ{_jW2n7dspXS z?Ne4J6Tj5S{;Gbb1zabX1boDo$;rLnKeSR?A%%wzYtGqOD=P!M(ghTm=5}3>R1<9) zCPUnT$pG2!@&NPm{=j@a(t@T&ScSd3jE#<2FCipv2nN37(p!223-wT4!}N_;_g?nA ze-HmiyCUxEw{MM0V`F3Xq?Iv+v~U?0nr34~ z<}uU*X&WIiu}J9hk1#lUEnggmFz`ZTYM9nubBdW&4`e?0Fj#WNmKs3RAHROZeRUO_ z?ygxu7BN~igO`wm|GSeL51Fu^>@7>bC1(EIrtp49Fm5ho|G~k*Lg^|};+bzc(EljG z7C~=Sv2G}mM+1@WIPhJQ)H{xUao&4zmw><clqDTOAnW8(*8V)+B_15j+Gz6O^+*L!X4Um}z)EuPCj^hRj@oH_Ira4I%dXwD^q^70l-sY2$K{?rG-u8kixmC4xDOt=iV>^q&j&NfKN zw%*)~70Sw(6GPiU$nQf-9z}6OKb^k*QFIDeEyX>>O>R`{+C<&Nst3to{H6t>QW<42 zh8C63>eeUvza(IO}xmqjFZSjJ$30xfp*V7KmZlTWwH#faatv%`6<{GY# zew^m{!v^~WDfBorB$-0>)X2z4ue(vZ#!%)84-V5=@TK?mp6b%#DuCgX1KmK*Apa#QT#*M<)h<1J@SrZlUAh-JN0m2ma0onWL1Gfj#~x}-;q8X^?{(|i zhoGKgIXsz?jDs%pV_AZOYS>sUE}HfL&uMS`O%b4V_4JgYE-r>ig0ql14_c6RbH9F# z+`Y1rW)z);K__Bje^-Mq&ddf%uHdEI5m&G?gWk4PC!XsLj*fNAi9ynwE32bfDXpMk z^x~TjId8oj78+_(eIF16N>(crK-*@D=U7KbLD?a+JLH^rlVei_X$qo$*pB?fD{dc5 ze+q?Qp&`{b8<@8;F<6924^i4!^Q#;NKxCuYS_`ut&S*b6iwITg50peptlBIJ5|aO; z6Z~fZ5uEPm>?{$C0P4PLC#c2g=j*#kfr`#a{+GY97L~PftRVw&>~2Kn3fV@5lg>Fs zF|R+g`x>cWYA(EcM%Bcw-Gc2M9ibeo8p8VTDn-Z|dvrjb50VLd>p=QZqycgGROCZ+ zas3`-WwF06l2K!DoQnLT=iD|Jy@0{RQpC<0(rnN!zlPGja_t&z#%=e}AG5Q?c3)lG zLC-QiKE)lXj2e(V1CKY{__BKE;i(GiAM7X6cMTR$!=CeVV!IHQ=usUxgrHscMUHuv zARHd+T~ef=UL8L;$Iv7sD#|f&Gz8J|sFaixt4i_)B-8k@l#N#y2qD(FCM5 zTqiDigslKg0Db8T9di?;P&Maz3wJr!?Ap<;BIh1y2Rpm`x5hPjN*|sg=hrW2ar}7r z_l4aN3L2kNdsbbOi}WUZr{fbph(!4@36mRCF-!On9&l2_`J7|kG>Zk`7@#4{d^nc%PSA3 zNj^T>Kgc7TAv-50Ixdb5(Uh$FK8)~Z3N)W92Vw}wPHr@&)S|wTK3fg2A&uu&?dcc=WCm~T*H|ZWr0kF2H~EnBY7pv}M(*KHE^W}MOaXCF>^b-ngS-z> zNz&qNn&=#xEk#!v@OcY*>0uo!&eC~nK6>(`7B@4L`oR9xm3g#bk{+Z;sF{Jm^(Rk6 z;^N{+YD~ITn{?bvGGJ38OEDeUL^2z6aDH@cUDCKpMi5bio=mbizK1Kd1 zms>P|XM{sW-@bhlCAS)F8QDIL>XnwyJs5x`$gIzOuOWt`uMWYakJ?%ur~ zKG;z=w|!PtRs$0gpXw9#w6#pQ>WHv?+dTdOYZi3}S@S=p8&i zKy3G&ekEPfvWe%@%G@;*QU~~meAnk4=t4{=D=RPm3?D(F1Ooa%-j!uXv&6Fr5W#lx z^V6Zvpjs=ZopiD>tt3iF+5=SM$h;qiSNu1f-O$Fw!a6tkDGbDT2zJm>&UQOV(nGem zvg}E23IpZO`J86_=*V2jda;UWT2h@sxUZx{2xPd50Y3Z;$_F3%FVOV_y^NN?+!}ar z+oBqdEZT=&UZm(J!-s!jR>FkcbR^DnLUn|MgzkfZ9qLO@Cs#?yL#)r+e~b$olPGioP?i#5Y(A{aYIOpJ26z z?n0E*pG*5Eibv-2?l3X_Re=xt|NrxB`p<#?`!X59|Kr1uNBQpy{`@kU`k!m}-w$j% aLRFX#3}1E9N9zOrsjFx!KRtT-+W!OZ05hZj literal 0 HcmV?d00001 diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..25f61ed --- /dev/null +++ b/inference.py @@ -0,0 +1,53 @@ +import torch +from transformers import AutoModelForCausalLM + +from deepseek_vlm.models import VLChatProcessor, MultiModalityCausalLM +from deepseek_vlm.utils.io import load_pil_images + + +# specify the path to the model +model_path = "deepseek-ai/deepseek-vl-7b-chat" +vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) +tokenizer = vl_chat_processor.tokenizer + +vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) +vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval() + +conversation = [ + { + "role": "User", + "content": "Describe each stage of this image.", + "images": ["./images/training_pipelines.png"] + }, + { + "role": "Assistant", + "content": "" + } +] + + +# load images and prepare for inputs +pil_images = load_pil_images(conversation) +prepare_inputs = vl_chat_processor( + conversations=conversation, + images=pil_images, + force_batchify=True +).to(vl_gpt.device) + +# run image encoder to get the image embeddings +inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) + +# run the model to get the response +outputs = vl_gpt.language_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=prepare_inputs.attention_mask, + pad_token_id=tokenizer.eos_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=512, + do_sample=False, + use_cache=True +) + +answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True) +print(f"{prepare_inputs['sft_format'][0]}", answer) diff --git a/requirements.txt b/requirements.txt index 2a9d41f..589b5bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,5 @@ -torch>=2.0 -tokenizers>=0.14.0 -transformers>=4.35.0 +transformers>=4.38.2 +timm>=0.9.16 +gradio>=4.13.0 accelerate -sympy==1.12 -pebble -timeout-decorator -attrdict +sentencepiece \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..7f8dbf0 --- /dev/null +++ b/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup, find_packages + + +version = '1.0.0' +print(version) + +setup( + name='deepseek_vlm', + version=version, + description='DeekSeel-VLM', + author='HFAiLab', + license='MIT', + url='https://gitlab.deepseek.com/liuwen/deepseek_vl', + python_requires='>=3.8', + install_requires=['torch>=2.0'], + packages=find_packages(exclude=("images",)), +)