# coding: utf-8
"""
    comfyui
    No description provided (generated by Openapi JSON Schema Generator https://github.com/openapi-json-schema-tools/openapi-json-schema-generator)  # noqa: E501
    The version of the OpenAPI document: 0.0.1
    Generated by: https://github.com/openapi-json-schema-tools/openapi-json-schema-generator
"""

from __future__ import annotations
import abc
import datetime
import dataclasses
import decimal
import enum
import email
import json
import os
import io
import atexit
from multiprocessing import pool
import re
import tempfile
import typing
import typing_extensions
from urllib import parse
import urllib3
from urllib3 import _collections, fields


from comfy.api import exceptions, rest, schemas, security_schemes, api_response
from comfy.api.configurations import api_configuration, schema_configuration as schema_configuration_


class JSONEncoder(json.JSONEncoder):
    compact_separators = (',', ':')

    def default(self, obj: typing.Any):
        if isinstance(obj, str):
            return str(obj)
        elif isinstance(obj, float):
            return obj
        elif isinstance(obj, bool):
            # must be before int check
            return obj
        elif isinstance(obj, int):
            return obj
        elif obj is None:
            return None
        elif isinstance(obj, (dict, schemas.immutabledict)):
            return {key: self.default(val) for key, val in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [self.default(item) for item in obj]
        raise exceptions.ApiValueError('Unable to prepare type {} for serialization'.format(obj.__class__.__name__))
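
# Example (illustrative, not part of the generated code): JSONEncoder.default is
# called directly throughout this module to turn validated schema values into
# plain JSON-ready builtins, e.g.
#
#   JSONEncoder().default({'flag': True, 'sizes': (1, 2.5)})
#   # -> {'flag': True, 'sizes': [1, 2.5]}
#
# Unsupported types (for example bytes) raise exceptions.ApiValueError.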


class ParameterInType(enum.Enum):
    QUERY = 'query'
    HEADER = 'header'
    PATH = 'path'
    COOKIE = 'cookie'


class ParameterStyle(enum.Enum):
    MATRIX = 'matrix'
    LABEL = 'label'
    FORM = 'form'
    SIMPLE = 'simple'
    SPACE_DELIMITED = 'spaceDelimited'
    PIPE_DELIMITED = 'pipeDelimited'
    DEEP_OBJECT = 'deepObject'


@dataclasses.dataclass
class PrefixSeparatorIterator:
    # A class to store prefixes and separators for rfc6570 expansions
    prefix: str
    separator: str
    first: bool = True
    item_separator: str = dataclasses.field(init=False)

    def __post_init__(self):
        self.item_separator = self.separator if self.separator in {'.', '|', '%20'} else ','

    def __iter__(self):
        return self

    def __next__(self):
        if self.first:
            self.first = False
            return self.prefix
        return self.separator
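
# Example (illustrative): the iterator yields its prefix once, then its separator
# forever, which is how RFC 6570 expansions join successive variables:
#
#   it = PrefixSeparatorIterator('?', '&')
#   next(it)           # '?'
#   next(it)           # '&'
#   it.item_separator  # ',' because '&' is not in {'.', '|', '%20'}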


class ParameterSerializerBase:
    @staticmethod
    def __ref6570_item_value(in_data: typing.Any, percent_encode: bool):
        """
        Get the representation of a str/float/int/None value, of items in a list, or of values in a dict.
        None is returned if an item is undefined; use cases are value=
        - None
        - []
        - {}
        - [None, None, None]
        - {'a': None, 'b': None}
        """
        if type(in_data) in {str, float, int}:
            if percent_encode:
                return parse.quote(str(in_data))
            return str(in_data)
        elif in_data is None:
            # ignored by the expansion process https://datatracker.ietf.org/doc/html/rfc6570#section-3.2.1
            return None
        elif isinstance(in_data, list) and not in_data:
            # ignored by the expansion process https://datatracker.ietf.org/doc/html/rfc6570#section-3.2.1
            return None
        elif isinstance(in_data, dict) and not in_data:
            # ignored by the expansion process https://datatracker.ietf.org/doc/html/rfc6570#section-3.2.1
            return None
        raise exceptions.ApiValueError('Unable to generate a ref6570 item representation of {}'.format(in_data))

    @staticmethod
    def _to_dict(name: str, value: str):
        return {name: value}

    @classmethod
    def __ref6570_str_float_int_expansion(
        cls,
        variable_name: str,
        in_data: typing.Any,
        explode: bool,
        percent_encode: bool,
        prefix_separator_iterator: PrefixSeparatorIterator,
        var_name_piece: str,
        named_parameter_expansion: bool
    ) -> str:
        item_value = cls.__ref6570_item_value(in_data, percent_encode)
        if item_value is None or (item_value == '' and prefix_separator_iterator.separator == ';'):
            return next(prefix_separator_iterator) + var_name_piece
        value_pair_equals = '=' if named_parameter_expansion else ''
        return next(prefix_separator_iterator) + var_name_piece + value_pair_equals + item_value

    @classmethod
    def __ref6570_list_expansion(
        cls,
        variable_name: str,
        in_data: typing.Any,
        explode: bool,
        percent_encode: bool,
        prefix_separator_iterator: PrefixSeparatorIterator,
        var_name_piece: str,
        named_parameter_expansion: bool
    ) -> str:
        item_values = [cls.__ref6570_item_value(v, percent_encode) for v in in_data]
        item_values = [v for v in item_values if v is not None]
        if not item_values:
            # ignored by the expansion process https://datatracker.ietf.org/doc/html/rfc6570#section-3.2.1
            return ""
        value_pair_equals = '=' if named_parameter_expansion else ''
        if not explode:
            return (
                next(prefix_separator_iterator) +
                var_name_piece +
                value_pair_equals +
                prefix_separator_iterator.item_separator.join(item_values)
            )
        # exploded
        return next(prefix_separator_iterator) + next(prefix_separator_iterator).join(
            [var_name_piece + value_pair_equals + val for val in item_values]
        )

    @classmethod
    def __ref6570_dict_expansion(
        cls,
        variable_name: str,
        in_data: typing.Any,
        explode: bool,
        percent_encode: bool,
        prefix_separator_iterator: PrefixSeparatorIterator,
        var_name_piece: str,
        named_parameter_expansion: bool
    ) -> str:
        in_data_transformed = {key: cls.__ref6570_item_value(val, percent_encode) for key, val in in_data.items()}
        in_data_transformed = {key: val for key, val in in_data_transformed.items() if val is not None}
        if not in_data_transformed:
            # ignored by the expansion process https://datatracker.ietf.org/doc/html/rfc6570#section-3.2.1
            return ""
        value_pair_equals = '=' if named_parameter_expansion else ''
        if not explode:
            return (
                next(prefix_separator_iterator) +
                var_name_piece + value_pair_equals +
                prefix_separator_iterator.item_separator.join(
                    prefix_separator_iterator.item_separator.join(
                        item_pair
                    ) for item_pair in in_data_transformed.items()
                )
            )
        # exploded
        return next(prefix_separator_iterator) + next(prefix_separator_iterator).join(
            [key + '=' + val for key, val in in_data_transformed.items()]
        )

    @classmethod
    def _ref6570_expansion(
        cls,
        variable_name: str,
        in_data: typing.Any,
        explode: bool,
        percent_encode: bool,
        prefix_separator_iterator: PrefixSeparatorIterator
    ) -> str:
        """
        The separator separates variables (for example a dict with explode=True),
        not the items of an array value.
        """
        named_parameter_expansion = prefix_separator_iterator.separator in {'&', ';'}
        var_name_piece = variable_name if named_parameter_expansion else ''
        if type(in_data) in {str, float, int}:
            return cls.__ref6570_str_float_int_expansion(
                variable_name,
                in_data,
                explode,
                percent_encode,
                prefix_separator_iterator,
                var_name_piece,
                named_parameter_expansion
            )
        elif in_data is None:
            # ignored by the expansion process https://datatracker.ietf.org/doc/html/rfc6570#section-3.2.1
            return ""
        elif isinstance(in_data, list):
            return cls.__ref6570_list_expansion(
                variable_name,
                in_data,
                explode,
                percent_encode,
                prefix_separator_iterator,
                var_name_piece,
                named_parameter_expansion
            )
        elif isinstance(in_data, dict):
            return cls.__ref6570_dict_expansion(
                variable_name,
                in_data,
                explode,
                percent_encode,
                prefix_separator_iterator,
                var_name_piece,
                named_parameter_expansion
            )
        # bool, bytes, etc
        raise exceptions.ApiValueError('Unable to generate a ref6570 representation of {}'.format(in_data))
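
# Example (illustrative): _ref6570_expansion drives every parameter style. With
# the query/form iterator PrefixSeparatorIterator('?', '&') and
# variable_name='color', it produces:
#
#   'blue'                              -> '?color=blue'
#   ['a', 'b'], explode=False           -> '?color=a,b'
#   ['a', 'b'], explode=True            -> '?color=a&color=b'
#   {'R': 100, 'G': 200}, explode=True  -> '?R=100&G=200'
#
# None, [] and {} expand to '' and are skipped, per RFC 6570 section 3.2.1.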


class StyleFormSerializer(ParameterSerializerBase):
    @classmethod
    def _serialize_form(
        cls,
        in_data: typing.Union[None, int, float, str, bool, dict, list],
        name: str,
        explode: bool,
        percent_encode: bool,
        prefix_separator_iterator: typing.Optional[PrefixSeparatorIterator] = None
    ) -> str:
        if prefix_separator_iterator is None:
            prefix_separator_iterator = PrefixSeparatorIterator('', '&')
        return cls._ref6570_expansion(
            variable_name=name,
            in_data=in_data,
            explode=explode,
            percent_encode=percent_encode,
            prefix_separator_iterator=prefix_separator_iterator
        )
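
# Example (illustrative): when no iterator is passed in, _serialize_form uses
# PrefixSeparatorIterator('', '&'), so the first variable gets no leading
# separator:
#
#   StyleFormSerializer._serialize_form('blue', name='color', explode=True, percent_encode=True)
#   # -> 'color=blue'
#   StyleFormSerializer._serialize_form({'R': 100, 'G': 200}, name='color', explode=True, percent_encode=True)
#   # -> 'R=100&G=200'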


class StyleSimpleSerializer(ParameterSerializerBase):

    @classmethod
    def _serialize_simple(
        cls,
        in_data: typing.Union[None, int, float, str, bool, dict, list],
        name: str,
        explode: bool,
        percent_encode: bool
    ) -> str:
        prefix_separator_iterator = PrefixSeparatorIterator('', ',')
        return cls._ref6570_expansion(
            variable_name=name,
            in_data=in_data,
            explode=explode,
            percent_encode=percent_encode,
            prefix_separator_iterator=prefix_separator_iterator
        )

    @classmethod
    def _deserialize_simple(
        cls,
        in_data: str,
        name: str,
        explode: bool,
        percent_encode: bool
    ) -> typing.Union[str, typing.List[str], typing.Dict[str, str]]:
        raise NotImplementedError(
            "Deserialization of style=simple has not yet been added. "
            "If you need this how about you submit a PR adding it?"
        )
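
# Example (illustrative): the simple style joins items with ',' and never emits
# the parameter name; exploded dicts render as key=value pairs:
#
#   StyleSimpleSerializer._serialize_simple(['a', 'b'], name='id', explode=False, percent_encode=False)
#   # -> 'a,b'
#   StyleSimpleSerializer._serialize_simple({'R': 100, 'G': 200}, name='color', explode=True, percent_encode=False)
#   # -> 'R=100,G=200'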


class JSONDetector:
    """
    Works for:
    application/json
    application/json; charset=UTF-8
    application/json-patch+json
    application/geo+json
    """
    __json_content_type_pattern = re.compile("application/[^+]*[+]?(json);?.*")

    @classmethod
    def _content_type_is_json(cls, content_type: str) -> bool:
        if cls.__json_content_type_pattern.match(content_type):
            return True
        return False


class Encoding:
    content_type: str
    headers: typing.Optional[typing.Dict[str, 'HeaderParameter']] = None
    style: typing.Optional[ParameterStyle] = None
    explode: bool = False
    allow_reserved: bool = False


class MediaType:
    """
    Used to store request and response body schema information
    encoding:
        A map between a property name and its encoding information.
        The key, being the property name, MUST exist in the schema as a property.
        The encoding object SHALL only apply to requestBody objects when the media type is
        multipart or application/x-www-form-urlencoded.
    """
    schema: typing.Optional[typing.Type[schemas.Schema]] = None
    encoding: typing.Optional[typing.Dict[str, Encoding]] = None
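
# Illustrative sketch (hypothetical names, not taken from the generated endpoint
# modules): content mappings in this module expect MediaType subclasses that bind
# a content type to a schema class, e.g.
#
#   class _ApplicationJson(MediaType):
#       schema = SomeRequestSchema  # a schemas.Schema subclass from the generated models
#
# which is then referenced as {'application/json': _ApplicationJson} in a
# parameter's or request body's `content` mapping.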


class ParameterBase(JSONDetector):
    in_type: ParameterInType
    required: bool
    style: typing.Optional[ParameterStyle]
    explode: typing.Optional[bool]
    allow_reserved: typing.Optional[bool]
    schema: typing.Optional[typing.Type[schemas.Schema]]
    content: typing.Optional[typing.Dict[str, typing.Type[MediaType]]]

    _json_encoder = JSONEncoder()

    def __init_subclass__(cls, **kwargs):
        if cls.explode is None:
            if cls.style is ParameterStyle.FORM:
                cls.explode = True
            else:
                cls.explode = False

    @classmethod
    def _serialize_json(
        cls,
        in_data: typing.Union[None, int, float, str, bool, dict, list],
        eliminate_whitespace: bool = False
    ) -> str:
        if eliminate_whitespace:
            return json.dumps(in_data, separators=cls._json_encoder.compact_separators)
        return json.dumps(in_data)


_SERIALIZE_TYPES = typing.Union[
    int,
    float,
    str,
    datetime.date,
    datetime.datetime,
    None,
    bool,
    list,
    tuple,
    dict,
    schemas.immutabledict
]

_JSON_TYPES = typing.Union[
    int,
    float,
    str,
    None,
    bool,
    typing.Tuple['_JSON_TYPES', ...],
    schemas.immutabledict[str, '_JSON_TYPES'],
]


@dataclasses.dataclass
class PathParameter(ParameterBase, StyleSimpleSerializer):
    name: str
    required: bool = False
    in_type: ParameterInType = ParameterInType.PATH
    style: ParameterStyle = ParameterStyle.SIMPLE
    explode: bool = False
    allow_reserved: typing.Optional[bool] = None
    schema: typing.Optional[typing.Type[schemas.Schema]] = None
    content: typing.Optional[typing.Dict[str, typing.Type[MediaType]]] = None

    @classmethod
    def __serialize_label(
        cls,
        in_data: typing.Union[None, int, float, str, bool, dict, list]
    ) -> typing.Dict[str, str]:
        prefix_separator_iterator = PrefixSeparatorIterator('.', '.')
        value = cls._ref6570_expansion(
            variable_name=cls.name,
            in_data=in_data,
            explode=cls.explode,
            percent_encode=True,
            prefix_separator_iterator=prefix_separator_iterator
        )
        return cls._to_dict(cls.name, value)

    @classmethod
    def __serialize_matrix(
        cls,
        in_data: typing.Union[None, int, float, str, bool, dict, list]
    ) -> typing.Dict[str, str]:
        prefix_separator_iterator = PrefixSeparatorIterator(';', ';')
        value = cls._ref6570_expansion(
            variable_name=cls.name,
            in_data=in_data,
            explode=cls.explode,
            percent_encode=True,
            prefix_separator_iterator=prefix_separator_iterator
        )
        return cls._to_dict(cls.name, value)

    @classmethod
    def __serialize_simple(
        cls,
        in_data: typing.Union[None, int, float, str, bool, dict, list],
    ) -> typing.Dict[str, str]:
        value = cls._serialize_simple(
            in_data=in_data,
            name=cls.name,
            explode=cls.explode,
            percent_encode=True
        )
        return cls._to_dict(cls.name, value)

    @classmethod
    def serialize(
        cls,
        in_data: _SERIALIZE_TYPES,
        skip_validation: bool = False
    ) -> typing.Dict[str, str]:
        if cls.schema:
            cast_in_data = in_data if skip_validation else cls.schema.validate_base(in_data)
            cast_in_data = cls._json_encoder.default(cast_in_data)
            """
            simple -> path
                path:
                    returns path_params: dict
            label -> path
                returns path_params
            matrix -> path
                returns path_params
            """
            if cls.style:
                if cls.style is ParameterStyle.SIMPLE:
                    return cls.__serialize_simple(cast_in_data)
                elif cls.style is ParameterStyle.LABEL:
                    return cls.__serialize_label(cast_in_data)
                elif cls.style is ParameterStyle.MATRIX:
                    return cls.__serialize_matrix(cast_in_data)
        assert cls.content is not None
        for content_type, media_type in cls.content.items():
            assert media_type.schema is not None
            cast_in_data = in_data if skip_validation else media_type.schema.validate_base(in_data)
            cast_in_data = cls._json_encoder.default(cast_in_data)
            if cls._content_type_is_json(content_type):
                value = cls._serialize_json(cast_in_data)
                return cls._to_dict(cls.name, value)
            else:
                raise NotImplementedError('Serialization of {} has not yet been implemented'.format(content_type))
        raise ValueError('Invalid value for content, it was empty and must have 1 key value pair')
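
# Illustrative sketch (hypothetical parameter class): for a path parameter named
# 'id' whose schema validates integers, serialize(5) yields
#   {'id': '5'}      with style=SIMPLE (the default),
#   {'id': '.5'}     with style=LABEL, and
#   {'id': ';id=5'}  with style=MATRIX,
# and Api._get_used_path below substitutes the value into the '{id}' placeholder
# of the path template.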


@dataclasses.dataclass
class QueryParameter(ParameterBase, StyleFormSerializer):
    name: str
    required: bool = False
    in_type: ParameterInType = ParameterInType.QUERY
    style: ParameterStyle = ParameterStyle.FORM
    explode: typing.Optional[bool] = None
    allow_reserved: typing.Optional[bool] = None
    schema: typing.Optional[typing.Type[schemas.Schema]] = None
    content: typing.Optional[typing.Dict[str, typing.Type[MediaType]]] = None

    @classmethod
    def __serialize_space_delimited(
        cls,
        in_data: typing.Union[None, int, float, str, bool, dict, list],
        prefix_separator_iterator: typing.Optional[PrefixSeparatorIterator],
        explode: bool
    ) -> typing.Dict[str, str]:
        if prefix_separator_iterator is None:
            prefix_separator_iterator = cls.get_prefix_separator_iterator()
        value = cls._ref6570_expansion(
            variable_name=cls.name,
            in_data=in_data,
            explode=explode,
            percent_encode=True,
            prefix_separator_iterator=prefix_separator_iterator
        )
        return cls._to_dict(cls.name, value)

    @classmethod
    def __serialize_pipe_delimited(
        cls,
        in_data: typing.Union[None, int, float, str, bool, dict, list],
        prefix_separator_iterator: typing.Optional[PrefixSeparatorIterator],
        explode: bool
    ) -> typing.Dict[str, str]:
        if prefix_separator_iterator is None:
            prefix_separator_iterator = cls.get_prefix_separator_iterator()
        value = cls._ref6570_expansion(
            variable_name=cls.name,
            in_data=in_data,
            explode=explode,
            percent_encode=True,
            prefix_separator_iterator=prefix_separator_iterator
        )
        return cls._to_dict(cls.name, value)

    @classmethod
    def __serialize_form(
        cls,
        in_data: typing.Union[None, int, float, str, bool, dict, list],
        prefix_separator_iterator: typing.Optional[PrefixSeparatorIterator],
        explode: bool
    ) -> typing.Dict[str, str]:
        if prefix_separator_iterator is None:
            prefix_separator_iterator = cls.get_prefix_separator_iterator()
        value = cls._serialize_form(
            in_data,
            name=cls.name,
            explode=explode,
            percent_encode=True,
            prefix_separator_iterator=prefix_separator_iterator
        )
        return cls._to_dict(cls.name, value)

    @classmethod
    def get_prefix_separator_iterator(cls) -> PrefixSeparatorIterator:
        if cls.style is ParameterStyle.FORM:
            return PrefixSeparatorIterator('?', '&')
        elif cls.style is ParameterStyle.SPACE_DELIMITED:
            return PrefixSeparatorIterator('', '%20')
        elif cls.style is ParameterStyle.PIPE_DELIMITED:
            return PrefixSeparatorIterator('', '|')
        raise ValueError(f'No iterator possible for style={cls.style}')

    @classmethod
    def serialize(
        cls,
        in_data: _SERIALIZE_TYPES,
        prefix_separator_iterator: typing.Optional[PrefixSeparatorIterator] = None,
        skip_validation: bool = False
    ) -> typing.Dict[str, str]:
        if cls.schema:
            cast_in_data = in_data if skip_validation else cls.schema.validate_base(in_data)
            cast_in_data = cls._json_encoder.default(cast_in_data)
            """
            form -> query
                query:
                    - GET/HEAD/DELETE: could use fields
                    - PUT/POST: must use urlencode to send parameters
                returns fields: tuple
            spaceDelimited -> query
                returns fields
            pipeDelimited -> query
                returns fields
            deepObject -> query, https://github.com/OAI/OpenAPI-Specification/issues/1706
                returns fields
            """
            if cls.style:
                # TODO update query ones to omit setting values when [] {} or None is input
                explode = cls.explode if cls.explode is not None else cls.style == ParameterStyle.FORM
                if cls.style is ParameterStyle.FORM:
                    return cls.__serialize_form(cast_in_data, prefix_separator_iterator, explode)
                elif cls.style is ParameterStyle.SPACE_DELIMITED:
                    return cls.__serialize_space_delimited(cast_in_data, prefix_separator_iterator, explode)
                elif cls.style is ParameterStyle.PIPE_DELIMITED:
                    return cls.__serialize_pipe_delimited(cast_in_data, prefix_separator_iterator, explode)
        if prefix_separator_iterator is None:
            prefix_separator_iterator = cls.get_prefix_separator_iterator()
        assert cls.content is not None
        for content_type, media_type in cls.content.items():
            assert media_type.schema is not None
            cast_in_data = in_data if skip_validation else media_type.schema.validate_base(in_data)
            cast_in_data = cls._json_encoder.default(cast_in_data)
            if cls._content_type_is_json(content_type):
                value = cls._serialize_json(cast_in_data, eliminate_whitespace=True)
                return cls._to_dict(
                    cls.name,
                    next(prefix_separator_iterator) + cls.name + '=' + parse.quote(value)
                )
            else:
                raise NotImplementedError('Serialization of {} has not yet been implemented'.format(content_type))
        raise ValueError('Invalid value for content, it was empty and must have 1 key value pair')
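
# Illustrative sketch (hypothetical parameter classes): with the default FORM
# style, the first query parameter creates the '?'/'&' iterator and each
# serialized value is concatenated into the query string, e.g.
#
#   it = LimitParameter.get_prefix_separator_iterator()
#   LimitParameter.serialize(5, prefix_separator_iterator=it)    # {'limit': '?limit=5'}
#   OffsetParameter.serialize(10, prefix_separator_iterator=it)  # {'offset': '&offset=10'}
#
# Api._get_used_path performs exactly this loop to build query_params_suffix.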


@dataclasses.dataclass
class CookieParameter(ParameterBase, StyleFormSerializer):
    name: str
    required: bool = False
    style: ParameterStyle = ParameterStyle.FORM
    in_type: ParameterInType = ParameterInType.COOKIE
    explode: typing.Optional[bool] = None
    allow_reserved: typing.Optional[bool] = None
    schema: typing.Optional[typing.Type[schemas.Schema]] = None
    content: typing.Optional[typing.Dict[str, typing.Type[MediaType]]] = None

    @classmethod
    def serialize(
        cls,
        in_data: _SERIALIZE_TYPES,
        skip_validation: bool = False
    ) -> typing.Dict[str, str]:
        if cls.schema:
            cast_in_data = in_data if skip_validation else cls.schema.validate_base(in_data)
            cast_in_data = cls._json_encoder.default(cast_in_data)
            """
            form -> cookie
                returns fields: tuple
            """
            if cls.style:
                """
                TODO add escaping of comma, space, equals
                or turn encoding on
                """
                explode = cls.explode if cls.explode is not None else cls.style == ParameterStyle.FORM
                value = cls._serialize_form(
                    cast_in_data,
                    explode=explode,
                    name=cls.name,
                    percent_encode=False,
                    prefix_separator_iterator=PrefixSeparatorIterator('', '&')
                )
                return cls._to_dict(cls.name, value)
        assert cls.content is not None
        for content_type, media_type in cls.content.items():
            assert media_type.schema is not None
            cast_in_data = in_data if skip_validation else media_type.schema.validate_base(in_data)
            cast_in_data = cls._json_encoder.default(cast_in_data)
            if cls._content_type_is_json(content_type):
                value = cls._serialize_json(cast_in_data)
                return cls._to_dict(cls.name, value)
            else:
                raise NotImplementedError('Serialization of {} has not yet been implemented'.format(content_type))
        raise ValueError('Invalid value for content, it was empty and must have 1 key value pair')


class __HeaderParameterBase(ParameterBase, StyleSimpleSerializer):
    style: ParameterStyle = ParameterStyle.SIMPLE
    schema: typing.Optional[typing.Type[schemas.Schema]] = None
    content: typing.Optional[typing.Dict[str, typing.Type[MediaType]]] = None
    explode: bool = False

    @staticmethod
    def __to_headers(in_data: typing.Tuple[typing.Tuple[str, str], ...]) -> _collections.HTTPHeaderDict:
        data = tuple(t for t in in_data if t)
        headers = _collections.HTTPHeaderDict()
        if not data:
            return headers
        headers.extend(data)
        return headers

    @classmethod
    def serialize_with_name(
        cls,
        in_data: _SERIALIZE_TYPES,
        name: str,
        skip_validation: bool = False
    ) -> _collections.HTTPHeaderDict:
        if cls.schema:
            cast_in_data = in_data if skip_validation else cls.schema.validate_base(in_data)
            cast_in_data = cls._json_encoder.default(cast_in_data)
            """
            simple -> header
                headers: PoolManager needs a mapping, tuple is close
                returns headers: dict
            """
            if cls.style:
                value = cls._serialize_simple(cast_in_data, name, cls.explode, False)
                return cls.__to_headers(((name, value),))
        assert cls.content is not None
        for content_type, media_type in cls.content.items():
            assert media_type.schema is not None
            cast_in_data = in_data if skip_validation else media_type.schema.validate_base(in_data)
            cast_in_data = cls._json_encoder.default(cast_in_data)
            if cls._content_type_is_json(content_type):
                value = cls._serialize_json(cast_in_data)
                return cls.__to_headers(((name, value),))
            else:
                raise NotImplementedError('Serialization of {} has not yet been implemented'.format(content_type))
        raise ValueError('Invalid value for content, it was empty and must have 1 key value pair')

    @classmethod
    def deserialize(
        cls,
        in_data: str,
        name: str
    ):
        if cls.schema:
            """
            simple -> header
                headers: PoolManager needs a mapping, tuple is close
                returns headers: dict
            """
            if cls.style:
                extracted_data = cls._deserialize_simple(in_data, name, cls.explode, False)
                return cls.schema.validate_base(extracted_data)
        assert cls.content is not None
        for content_type, media_type in cls.content.items():
            if cls._content_type_is_json(content_type):
                cast_in_data: typing.Union[dict, list, None, int, float, str] = json.loads(in_data)
                assert media_type.schema is not None
                return media_type.schema.validate_base(cast_in_data)
            else:
                raise NotImplementedError('Deserialization of {} has not yet been implemented'.format(content_type))
        raise ValueError('Invalid value for content, it was empty and must have 1 key value pair')


class HeaderParameterWithoutName(__HeaderParameterBase):
    required: bool = False
    style: ParameterStyle = ParameterStyle.SIMPLE
    in_type: ParameterInType = ParameterInType.HEADER
    explode: bool = False
    allow_reserved: typing.Optional[bool] = None
    schema: typing.Optional[typing.Type[schemas.Schema]] = None
    content: typing.Optional[typing.Dict[str, typing.Type[MediaType]]] = None

    @classmethod
    def serialize(
        cls,
        in_data: _SERIALIZE_TYPES,
        name: str,
        skip_validation: bool = False
    ) -> _collections.HTTPHeaderDict:
        return cls.serialize_with_name(
            in_data,
            name,
            skip_validation=skip_validation
        )


class HeaderParameter(__HeaderParameterBase):
    name: str
    required: bool = False
    style: ParameterStyle = ParameterStyle.SIMPLE
    in_type: ParameterInType = ParameterInType.HEADER
    explode: bool = False
    allow_reserved: typing.Optional[bool] = None
    schema: typing.Optional[typing.Type[schemas.Schema]] = None
    content: typing.Optional[typing.Dict[str, typing.Type[MediaType]]] = None

    @classmethod
    def serialize(
        cls,
        in_data: _SERIALIZE_TYPES,
        skip_validation: bool = False
    ) -> _collections.HTTPHeaderDict:
        return cls.serialize_with_name(
            in_data,
            cls.name,
            skip_validation=skip_validation
        )
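
# Illustrative sketch (hypothetical parameter class): a header parameter named
# 'X-Request-Id' with a string schema serializes with the simple style, so
# serialize('abc-123') returns an HTTPHeaderDict equivalent to
# {'X-Request-Id': 'abc-123'}.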


T = typing.TypeVar("T", bound=api_response.ApiResponse)


class OpenApiResponse(typing.Generic[T], JSONDetector, abc.ABC):
    __filename_content_disposition_pattern = re.compile('filename="(.+?)"')
    content: typing.Optional[typing.Dict[str, typing.Type[MediaType]]] = None
    headers: typing.Optional[typing.Dict[str, typing.Type[HeaderParameterWithoutName]]] = None
    headers_schema: typing.Optional[typing.Type[schemas.Schema]] = None

    @classmethod
    @abc.abstractmethod
    def get_response(cls, response, headers, body) -> T: ...

    @staticmethod
    def __deserialize_json(response: urllib3.HTTPResponse) -> typing.Any:
        # json.loads accepts bytes (Python >= 3.6), so response.data can be passed in directly
        return json.loads(response.data)

    @staticmethod
    def __file_name_from_response_url(response_url: typing.Optional[str]) -> typing.Optional[str]:
        if response_url is None:
            return None
        url_path = parse.urlparse(response_url).path
        if url_path:
            path_basename = os.path.basename(url_path)
            if path_basename:
                _filename, ext = os.path.splitext(path_basename)
                if ext:
                    return path_basename
        return None

    @classmethod
    def __file_name_from_content_disposition(cls, content_disposition: typing.Optional[str]) -> typing.Optional[str]:
        if content_disposition is None:
            return None
        match = cls.__filename_content_disposition_pattern.search(content_disposition)
        if not match:
            return None
        return match.group(1)

    @classmethod
    def __deserialize_application_octet_stream(
        cls, response: urllib3.HTTPResponse
    ) -> typing.Union[bytes, io.BufferedReader]:
        """
        urllib3 use cases:
        1. when preload_content=True (stream=False) then supports_chunked_reads is False and bytes are returned
        2. when preload_content=False (stream=True) then supports_chunked_reads is True and
            a file will be written and returned
        """
        if response.supports_chunked_reads():
            file_name = (
                cls.__file_name_from_content_disposition(response.headers.get('content-disposition'))
                or cls.__file_name_from_response_url(response.geturl())
            )

            if file_name is None:
                _fd, path = tempfile.mkstemp()
            else:
                path = os.path.join(tempfile.gettempdir(), file_name)

            with open(path, 'wb') as write_file:
                chunk_size = 1024
                while True:
                    data = response.read(chunk_size)
                    if not data:
                        break
                    write_file.write(data)
            # release_conn is needed for streaming connections only
            response.release_conn()
            new_file = open(path, 'rb')
            return new_file
        else:
            return response.data

    @staticmethod
    def __deserialize_multipart_form_data(
        response: urllib3.HTTPResponse
    ) -> typing.Dict[str, typing.Any]:
        msg = email.message_from_bytes(response.data)
        return {
            part.get_param("name", header="Content-Disposition"): part.get_payload(
                decode=True
            ).decode(part.get_content_charset())
            if part.get_content_charset()
            else part.get_payload()
            for part in msg.get_payload()
        }

    @classmethod
    def deserialize(cls, response: urllib3.HTTPResponse, configuration: schema_configuration_.SchemaConfiguration) -> T:
        content_type = response.headers.get('content-type')
        deserialized_body = schemas.unset
        streamed = response.supports_chunked_reads()

        deserialized_headers: typing.Union[schemas.Unset, typing.Dict[str, typing.Any]] = schemas.unset
        if cls.headers is not None and cls.headers_schema is not None:
            deserialized_headers = {}
            for header_name, header_param in cls.headers.items():
                header_value = response.headers.get(header_name)
                if header_value is None:
                    continue
                header_value = header_param.deserialize(header_value, header_name)
                deserialized_headers[header_name] = header_value
            deserialized_headers = cls.headers_schema.validate_base(deserialized_headers, configuration=configuration)

        if cls.content is not None:
            if content_type not in cls.content:
                raise exceptions.ApiValueError(
                    f"Invalid content_type returned. Content_type='{content_type}' was returned "
                    f"when only {str(set(cls.content))} are defined for status_code={str(response.status)}"
                )
            body_schema = cls.content[content_type].schema
            if body_schema is None:
                # some specs do not define response content media type schemas
                return cls.get_response(
                    response=response,
                    headers=deserialized_headers,
                    body=schemas.unset
                )

            if cls._content_type_is_json(content_type):
                body_data = cls.__deserialize_json(response)
            elif content_type == 'application/octet-stream':
                body_data = cls.__deserialize_application_octet_stream(response)
            elif content_type.startswith('multipart/form-data'):
                body_data = cls.__deserialize_multipart_form_data(response)
                content_type = 'multipart/form-data'
            elif content_type == 'application/x-pem-file':
                body_data = response.data.decode()
            else:
                raise NotImplementedError('Deserialization of {} has not yet been implemented'.format(content_type))
            body_schema = schemas.get_class(body_schema)
            if body_schema is schemas.BinarySchema:
                deserialized_body = body_schema.validate_base(body_data)
            else:
                deserialized_body = body_schema.validate_base(
                    body_data, configuration=configuration)
        elif streamed:
            response.release_conn()

        return cls.get_response(
            response=response,
            headers=deserialized_headers,
            body=deserialized_body
        )
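
# Illustrative sketch (hypothetical names, not generated code): concrete
# responses subclass OpenApiResponse, declare their `content` mapping and build
# the typed ApiResponse in get_response, e.g.
#
#   class _SuccessResponse(OpenApiResponse[api_response.ApiResponse]):
#       content = {'application/json': _ApplicationJson}  # hypothetical MediaType subclass
#
#       @classmethod
#       def get_response(cls, response, headers, body) -> api_response.ApiResponse:
#           return api_response.ApiResponse(response=response, body=body, headers=headers)
#
# deserialize() then validates the JSON body against the schema bound to
# 'application/json' before calling get_response.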


@dataclasses.dataclass
class ApiClient:
    """Generic API client for OpenAPI client library builds.

    OpenAPI generic API client. This client handles the client-
    server communication, and is invariant across implementations. Specifics of
    the methods and models for each application are generated from the OpenAPI
    templates.

    NOTE: This class is auto generated by OpenAPI JSON Schema Generator.
    Ref: https://github.com/openapi-json-schema-tools/openapi-json-schema-generator
    Do not edit the class manually.

    :param configuration: api_configuration.ApiConfiguration object for this client
    :param schema_configuration: schema_configuration_.SchemaConfiguration object for this client
    :param default_headers: any default headers to include when making calls to the API.
    :param pool_threads: The number of threads to use for async requests
        to the API. More threads means more concurrent API requests.
    """
    configuration: api_configuration.ApiConfiguration = dataclasses.field(
        default_factory=lambda: api_configuration.ApiConfiguration())
    schema_configuration: schema_configuration_.SchemaConfiguration = dataclasses.field(
        default_factory=lambda: schema_configuration_.SchemaConfiguration())
    default_headers: _collections.HTTPHeaderDict = dataclasses.field(
        default_factory=lambda: _collections.HTTPHeaderDict())
    pool_threads: int = 1
    user_agent: str = 'OpenAPI-JSON-Schema-Generator/1.0.0/python'
    rest_client: rest.RESTClientObject = dataclasses.field(init=False)

    def __post_init__(self):
        self._pool = None
        self.rest_client = rest.RESTClientObject(self.configuration)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        if self._pool:
            self._pool.close()
            self._pool.join()
            self._pool = None
            if hasattr(atexit, 'unregister'):
                atexit.unregister(self.close)

    @property
    def pool(self):
        """Create the thread pool on the first request
        to avoid instantiating an unused threadpool for blocking clients.
        """
        if self._pool is None:
            atexit.register(self.close)
            self._pool = pool.ThreadPool(self.pool_threads)
        return self._pool

    def set_default_header(self, header_name: str, header_value: str):
        self.default_headers[header_name] = header_value

    def call_api(
        self,
        resource_path: str,
        method: str,
        host: str,
        query_params_suffix: typing.Optional[str] = None,
        headers: typing.Optional[_collections.HTTPHeaderDict] = None,
        body: typing.Union[str, bytes, None] = None,
        fields: typing.Optional[typing.Tuple[rest.RequestField, ...]] = None,
        security_requirement_object: typing.Optional[security_schemes.SecurityRequirementObject] = None,
        stream: bool = False,
        timeout: typing.Union[int, float, typing.Tuple, None] = None,
    ) -> urllib3.HTTPResponse:
        """Makes the HTTP request (synchronous) and returns the response.

        :param resource_path: Path to method endpoint.
        :param method: Method to call.
        :param headers: Header parameters to be
            placed in the request header.
        :param body: Request body.
        :param fields: Request post form parameters,
            for `application/x-www-form-urlencoded`, `multipart/form-data`
        :param security_requirement_object: The security requirement object, used to apply auth when making the call
        :param stream: if True, the urllib3.HTTPResponse object will
            be returned without reading/decoding response
            data. Also when True, if the openapi spec describes a file download,
            the data will be written to a local filesystem file and the schemas.BinarySchema
            instance will also inherit from FileSchema and schemas.FileIO
            Default is False.
        :type stream: bool, optional
        :param timeout: timeout setting for this request. If one
            number provided, it will be total request
            timeout. It can also be a pair (tuple) of
            (connection, read) timeouts.
        :param host: api endpoint host
        :return:
            the method will return the response directly.
        """
        # header parameters
        used_headers = _collections.HTTPHeaderDict(self.default_headers)
        user_agent_key = 'User-Agent'
        if user_agent_key not in used_headers and self.user_agent:
            used_headers[user_agent_key] = self.user_agent

        # auth setting
        self.update_params_for_auth(
            used_headers,
            security_requirement_object,
            resource_path,
            method,
            body,
            query_params_suffix
        )

        # must happen after auth setting in case user is overriding those
        if headers:
            used_headers.update(headers)

        # request url
        url = host + resource_path
        if query_params_suffix:
            url += query_params_suffix

        # perform request and return response
        response = self.request(
            method,
            url,
            headers=used_headers,
            fields=fields,
            body=body,
            stream=stream,
            timeout=timeout,
        )
        return response

    def request(
        self,
        method: str,
        url: str,
        headers: typing.Optional[_collections.HTTPHeaderDict] = None,
        fields: typing.Optional[typing.Tuple[rest.RequestField, ...]] = None,
        body: typing.Union[str, bytes, None] = None,
        stream: bool = False,
        timeout: typing.Union[int, float, typing.Tuple, None] = None,
    ) -> urllib3.HTTPResponse:
        """Makes the HTTP request using RESTClient."""
        if method == "get":
            return self.rest_client.get(url,
                                        stream=stream,
                                        timeout=timeout,
                                        headers=headers)
        elif method == "head":
            return self.rest_client.head(url,
                                         stream=stream,
                                         timeout=timeout,
                                         headers=headers)
        elif method == "options":
            return self.rest_client.options(url,
                                            headers=headers,
                                            fields=fields,
                                            stream=stream,
                                            timeout=timeout,
                                            body=body)
        elif method == "post":
            return self.rest_client.post(url,
                                         headers=headers,
                                         fields=fields,
                                         stream=stream,
                                         timeout=timeout,
                                         body=body)
        elif method == "put":
            return self.rest_client.put(url,
                                        headers=headers,
                                        fields=fields,
                                        stream=stream,
                                        timeout=timeout,
                                        body=body)
        elif method == "patch":
            return self.rest_client.patch(url,
                                          headers=headers,
                                          fields=fields,
                                          stream=stream,
                                          timeout=timeout,
                                          body=body)
        elif method == "delete":
            return self.rest_client.delete(url,
                                           headers=headers,
                                           stream=stream,
                                           timeout=timeout,
                                           body=body)
        else:
            raise exceptions.ApiValueError(
                "http method must be `GET`, `HEAD`, `OPTIONS`,"
                " `POST`, `PATCH`, `PUT` or `DELETE`."
            )

    def update_params_for_auth(
        self,
        headers: _collections.HTTPHeaderDict,
        security_requirement_object: typing.Optional[security_schemes.SecurityRequirementObject],
        resource_path: str,
        method: str,
        body: typing.Union[str, bytes, None] = None,
        query_params_suffix: typing.Optional[str] = None
    ):
        """Updates header and query params based on authentication setting.

        :param headers: Header parameters dict to be updated.
        :param security_requirement_object: the openapi security requirement object
        :param resource_path: A string representation of the HTTP request resource path.
        :param method: A string representation of the HTTP request method.
        :param body: An object representing the body of the HTTP request.
            The object type is the return value of _encoder.default().
        """
        return
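
# Illustrative sketch (endpoint path and host are assumptions, not taken from the
# spec): the client can be used directly as a context manager; the generated Api
# subclasses normally do this wiring for you.
#
#   with ApiClient(configuration=api_configuration.ApiConfiguration()) as client:
#       response = client.call_api(
#           resource_path='/prompt',       # hypothetical endpoint
#           method='get',
#           host='http://127.0.0.1:8188',  # hypothetical ComfyUI host
#           headers=_collections.HTTPHeaderDict({'Accept': 'application/json'}),
#       )
#       data = json.loads(response.data)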


@dataclasses.dataclass
class Api:
    """NOTE: This class is auto generated by OpenAPI JSON Schema Generator
    Ref: https://github.com/openapi-json-schema-tools/openapi-json-schema-generator

    Do not edit the class manually.
    """
    api_client: ApiClient = dataclasses.field(default_factory=lambda: ApiClient())

    @staticmethod
    def _get_used_path(
        used_path: str,
        path_parameters: typing.Tuple[typing.Type[PathParameter], ...] = (),
        path_params: typing.Optional[typing.Mapping[str, schemas.OUTPUT_BASE_TYPES]] = None,
        query_parameters: typing.Tuple[typing.Type[QueryParameter], ...] = (),
        query_params: typing.Optional[typing.Mapping[str, schemas.OUTPUT_BASE_TYPES]] = None,
        skip_validation: bool = False
    ) -> typing.Tuple[str, str]:
        used_path_params = {}
        if path_params is not None:
            for path_parameter in path_parameters:
                parameter_data = path_params.get(path_parameter.name, schemas.unset)
                if isinstance(parameter_data, schemas.Unset):
                    continue
                assert not isinstance(parameter_data, (bytes, schemas.FileIO))
                serialized_data = path_parameter.serialize(parameter_data, skip_validation=skip_validation)
                used_path_params.update(serialized_data)

        for k, v in used_path_params.items():
            used_path = used_path.replace('{%s}' % k, v)

        query_params_suffix = ""
        if query_params is not None:
            prefix_separator_iterator = None
            for query_parameter in query_parameters:
                parameter_data = query_params.get(query_parameter.name, schemas.unset)
                if isinstance(parameter_data, schemas.Unset):
                    continue
                if prefix_separator_iterator is None:
                    prefix_separator_iterator = query_parameter.get_prefix_separator_iterator()
                assert not isinstance(parameter_data, (bytes, schemas.FileIO))
                serialized_data = query_parameter.serialize(
                    parameter_data,
                    prefix_separator_iterator=prefix_separator_iterator,
                    skip_validation=skip_validation
                )
                for serialized_value in serialized_data.values():
                    query_params_suffix += serialized_value
        return used_path, query_params_suffix

    @staticmethod
    def _get_headers(
        header_parameters: typing.Tuple[typing.Type[HeaderParameter], ...] = (),
        header_params: typing.Optional[typing.Mapping[str, schemas.OUTPUT_BASE_TYPES]] = None,
        accept_content_types: typing.Tuple[str, ...] = (),
        skip_validation: bool = False
    ) -> _collections.HTTPHeaderDict:
        headers = _collections.HTTPHeaderDict()
        if header_params is not None:
            for parameter in header_parameters:
                parameter_data = header_params.get(parameter.name, schemas.unset)
                if isinstance(parameter_data, schemas.Unset):
                    continue
                assert not isinstance(parameter_data, (bytes, schemas.FileIO))
                serialized_data = parameter.serialize(parameter_data, skip_validation=skip_validation)
                headers.extend(serialized_data)
        if accept_content_types:
            for accept_content_type in accept_content_types:
                headers.add('Accept', accept_content_type)
        return headers

    def _get_fields_and_body(
        self,
        request_body: typing.Type[RequestBody],
        body: typing.Union[schemas.INPUT_TYPES_ALL, schemas.Unset],
        content_type: str,
        headers: _collections.HTTPHeaderDict
    ):
        if request_body.required and body is schemas.unset:
            raise exceptions.ApiValueError(
                'The required body parameter has an invalid value of: unset. Set a valid value instead')

        if isinstance(body, schemas.Unset):
            return None, None

        serialized_fields = None
        serialized_body = None
        serialized_data = request_body.serialize(body, content_type, configuration=self.api_client.schema_configuration)
        headers.add('Content-Type', content_type)
        if 'fields' in serialized_data:
            serialized_fields = serialized_data['fields']
        elif 'body' in serialized_data:
            serialized_body = serialized_data['body']
        return serialized_fields, serialized_body

    @staticmethod
    def _verify_response_status(response: api_response.ApiResponse):
        if not 200 <= response.response.status <= 399:
            raise exceptions.ApiException(
                status=response.response.status,
                reason=response.response.reason,
                api_response=response
            )
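
# Illustrative sketch (hypothetical parameter classes and path template):
#
#   used_path, query_suffix = Api._get_used_path(
#       used_path='/history/{prompt_id}',
#       path_parameters=(PromptIdParameter,),
#       path_params={'prompt_id': '42'},
#       query_parameters=(MaxItemsParameter,),
#       query_params={'max_items': 5},
#   )
#   # used_path == '/history/42', query_suffix == '?max_items=5'
#
# call_api() then receives used_path as resource_path and query_suffix as
# query_params_suffix.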


class SerializedRequestBody(typing.TypedDict, total=False):
    body: typing.Union[str, bytes]
    fields: typing.Tuple[rest.RequestField, ...]


class RequestBody(StyleFormSerializer, JSONDetector):
    """
    A request body parameter
    content: content_type to MediaType schemas.Schema info
    """
    __json_encoder = JSONEncoder()
    __plain_txt_content_types = {'text/plain', 'application/x-pem-file'}
    content: typing.Dict[str, typing.Type[MediaType]]
    required: bool = False

    @classmethod
    def __serialize_json(
        cls,
        in_data: _JSON_TYPES
    ) -> SerializedRequestBody:
        in_data = cls.__json_encoder.default(in_data)
        json_str = json.dumps(in_data, separators=(",", ":"), ensure_ascii=False).encode(
            "utf-8"
        )
        return {'body': json_str}

    @staticmethod
    def __serialize_text_plain(in_data: typing.Union[int, float, str]) -> SerializedRequestBody:
        return {'body': str(in_data)}

    @classmethod
    def __multipart_json_item(cls, key: str, value: _JSON_TYPES) -> rest.RequestField:
        json_value = cls.__json_encoder.default(value)
        request_field = rest.RequestField(name=key, data=json.dumps(json_value))
        request_field.make_multipart(content_type='application/json')
        return request_field

    @classmethod
    def __multipart_form_item(cls, key: str, value: typing.Union[_JSON_TYPES, bytes, schemas.FileIO]) -> rest.RequestField:
        if isinstance(value, str):
            request_field = rest.RequestField(name=key, data=str(value))
            request_field.make_multipart(content_type='text/plain')
        elif isinstance(value, bytes):
            request_field = rest.RequestField(name=key, data=value)
            request_field.make_multipart(content_type='application/octet-stream')
        elif isinstance(value, schemas.FileIO):
            # TODO use content.encoding to limit allowed content types if they are present
            urllib3_request_field = rest.RequestField.from_tuples(key, (os.path.basename(str(value.name)), value.read()))
            request_field = typing.cast(rest.RequestField, urllib3_request_field)
            value.close()
        else:
            request_field = cls.__multipart_json_item(key=key, value=value)
        return request_field

    @classmethod
    def __serialize_multipart_form_data(
        cls, in_data: schemas.immutabledict[str, typing.Union[_JSON_TYPES, bytes, schemas.FileIO]]
    ) -> SerializedRequestBody:
        """
        In a multipart/form-data request body, each schema property, or each element of a schema array property,
        takes a section in the payload with an internal header as defined by RFC7578. The serialization strategy
        for each property of a multipart/form-data request body can be specified in an associated Encoding Object.

        When passing in multipart types, boundaries MAY be used to separate sections of the content being
        transferred – thus, the following default Content-Types are defined for multipart:

        If the (object) property is a primitive, or an array of primitive values, the default Content-Type is text/plain
        If the property is complex, or an array of complex values, the default Content-Type is application/json
            Question: how is the array of primitives encoded?
        If the property is a type: string with a contentEncoding, the default Content-Type is application/octet-stream
        """
        fields = []
        for key, value in in_data.items():
            if isinstance(value, tuple):
                if value:
                    # values use explode = True, so the code makes a rest.RequestField for each item with name=key
                    for item in value:
                        request_field = cls.__multipart_form_item(key=key, value=item)
                        fields.append(request_field)
                else:
                    # send an empty array as json because exploding will not send it
                    request_field = cls.__multipart_json_item(key=key, value=value)  # type: ignore
                    fields.append(request_field)
            else:
                request_field = cls.__multipart_form_item(key=key, value=value)
                fields.append(request_field)

        return {'fields': tuple(fields)}

    @staticmethod
    def __serialize_application_octet_stream(in_data: typing.Union[schemas.FileIO, bytes]) -> SerializedRequestBody:
        if isinstance(in_data, bytes):
            return {'body': in_data}
        # schemas.FileIO type
        used_in_data = in_data.read()
        in_data.close()
        return {'body': used_in_data}

    @classmethod
    def __serialize_application_x_www_form_data(
        cls, in_data: schemas.immutabledict[str, _JSON_TYPES]
    ) -> SerializedRequestBody:
        """
        POST submission of form data in body
        """
        cast_in_data = cls.__json_encoder.default(in_data)
        value = cls._serialize_form(cast_in_data, name='', explode=True, percent_encode=True)
        return {'body': value}

    @classmethod
    def serialize(
        cls, in_data: schemas.INPUT_TYPES_ALL, content_type: str, configuration: typing.Optional[schema_configuration_.SchemaConfiguration] = None
    ) -> SerializedRequestBody:
        """
        Serializes the request body for the given content_type and returns a SerializedRequestBody dict.
        If 'body' is set, the result is assigned to the request data when making the request.
        If 'fields' is set, the result is used as the fields input of encode_multipart_formdata.
        The key of the returned dict is
        - body for application/json
        - fields for multipart/form-data
        """
        media_type = cls.content[content_type]
        assert media_type.schema is not None
        schema = schemas.get_class(media_type.schema)
        used_configuration = configuration if configuration is not None else schema_configuration_.SchemaConfiguration()
        cast_in_data = schema.validate_base(in_data, configuration=used_configuration)
        # TODO check for and use encoding if it exists
        # and content_type is multipart or application/x-www-form-urlencoded
        if cls._content_type_is_json(content_type):
            if isinstance(cast_in_data, (schemas.FileIO, bytes)):
                raise ValueError(f"Invalid input data type. Data must be int/float/str/bool/None/tuple/immutabledict and it was type {type(cast_in_data)}")
            return cls.__serialize_json(cast_in_data)
        elif content_type in cls.__plain_txt_content_types:
            if not isinstance(cast_in_data, (int, float, str)):
                raise ValueError(f"Unable to serialize type {type(cast_in_data)} to text/plain")
            return cls.__serialize_text_plain(cast_in_data)
        elif content_type == 'multipart/form-data':
            if not isinstance(cast_in_data, schemas.immutabledict):
                raise ValueError(f"Unable to serialize {cast_in_data} to multipart/form-data because it is not a dict of data")
            return cls.__serialize_multipart_form_data(cast_in_data)
        elif content_type == 'application/x-www-form-urlencoded':
            if not isinstance(cast_in_data, schemas.immutabledict):
                raise ValueError(
                    f"Unable to serialize {cast_in_data} to application/x-www-form-urlencoded because it is not a dict of data")
            return cls.__serialize_application_x_www_form_data(cast_in_data)
        elif content_type == 'application/octet-stream':
            if not isinstance(cast_in_data, (schemas.FileIO, bytes)):
                raise ValueError(f"Invalid input data type. Data must be bytes or File for content_type={content_type}")
            return cls.__serialize_application_octet_stream(cast_in_data)
        raise NotImplementedError('Serialization has not yet been implemented for {}'.format(content_type))
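
# Illustrative sketch (hypothetical request body class): a JSON request body
# declares its content mapping and is serialized right before the HTTP call,
# e.g.
#
#   class _PromptRequestBody(RequestBody):
#       content = {'application/json': _ApplicationJson}  # hypothetical MediaType subclass
#       required = True
#
#   serialized = _PromptRequestBody.serialize({'prompt': {}}, 'application/json')
#   # serialized == {'body': b'{"prompt":{}}'}
#
# Api._get_fields_and_body performs exactly this call and also adds the
# Content-Type header.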