# Source code for geoalchemy2.admin.dialects.mysql

"""This module defines specific functions for MySQL dialect."""

from sqlalchemy import text
from sqlalchemy.dialects.mysql.base import ischema_names as _mysql_ischema_names
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.sqltypes import NullType

from geoalchemy2 import functions
from geoalchemy2.admin.dialects.common import _check_spatial_type
from geoalchemy2.admin.dialects.common import _spatial_idx_name
from geoalchemy2.admin.dialects.common import compile_bin_literal
from geoalchemy2.admin.dialects.common import setup_create_drop
from geoalchemy2.types import Geography
from geoalchemy2.types import Geometry

# Column type names, as compared against INFORMATION_SCHEMA ``DATA_TYPE`` values,
# that are handled as spatial columns by this module.
_POSSIBLE_TYPES = [
    "geometry",
    "point",
    "linestring",
    "polygon",
    "multipoint",
    "multilinestring",
    "multipolygon",
    "geometrycollection",
]

# Register Geometry to SQLAlchemy's reflection subsystem: every spatial type name
# listed above is reflected as a Geometry column.
_mysql_ischema_names.update(dict.fromkeys(_POSSIBLE_TYPES, Geometry))


def reflect_geometry_column(inspector, table, column_info):
    """Reflect a column of type Geometry with MySQL or MariaDB dialect.

    The geometry type, SRID and nullability are read from
    ``INFORMATION_SCHEMA.COLUMNS`` and the presence of a spatial index is read from
    ``INFORMATION_SCHEMA.STATISTICS``. On success ``column_info["type"]`` is replaced
    in place by a fully described :class:`Geometry`.

    Args:
        inspector: The inspector performing the reflection; its ``bind``, ``dialect``
            and ``default_schema_name`` attributes are used.
        table: The table being reflected; its ``name`` and ``schema`` are used.
        column_info (dict): The reflected column description. Only its ``"name"``
            and ``"type"`` entries are read, and ``"type"`` may be overwritten.

    Returns:
        None. Nothing is changed when the reflected type is neither a ``Geometry``
        nor a ``NullType``, or when the database reports a non-spatial data type.
    """
    if not isinstance(column_info.get("type"), Geometry | NullType):
        return

    column_name = column_info.get("name")
    schema = table.schema or inspector.default_schema_name

    # Identifiers are sent as bound parameters instead of being interpolated into the
    # SQL text, so names containing quotes or colons cannot break (or alter) the query.
    filters = "TABLE_NAME = :table_name and COLUMN_NAME = :column_name"
    params = {"table_name": table.name, "column_name": column_name}
    if schema is not None:
        filters += " and TABLE_SCHEMA = :table_schema"
        params["table_schema"] = schema

    # With the "mariadb" dialect the SRID is reported as the constant -1 instead of
    # being read from the SRS_ID column.
    select_srid = "-1, " if inspector.dialect.name == "mariadb" else "SRS_ID, "

    # Check geometry type, SRID and if the column is nullable
    geometry_type_query = text(
        f"SELECT DATA_TYPE, {select_srid}IS_NULLABLE "
        f"FROM INFORMATION_SCHEMA.COLUMNS WHERE {filters}"
    ).bindparams(**params)
    geometry_type, srid, nullable_str = inspector.bind.execute(geometry_type_query).one()
    is_nullable = str(nullable_str).lower() == "yes"

    if geometry_type not in _POSSIBLE_TYPES:
        return  # pragma: no cover

    # Check if the column has spatial index. The query can return one row per
    # distinct index type, so every row is inspected rather than only the first one.
    has_index_query = text(
        f"SELECT DISTINCT INDEX_TYPE FROM INFORMATION_SCHEMA.STATISTICS WHERE {filters}"
    ).bindparams(**params)
    index_type_rows = inspector.bind.execute(has_index_query).fetchall()
    spatial_index = any(str(row[0]).lower() == "spatial" for row in index_type_rows)

    # Set attributes
    column_info["type"] = Geometry(
        geometry_type=geometry_type.upper(),
        srid=srid,
        spatial_index=spatial_index,
        nullable=is_nullable,
        _spatial_index_reflected=True,
    )
[docs] def before_cursor_execute(conn, cursor, statement, parameters, context, executemany, convert=True): # noqa: D417 """Event handler to cast the parameters properly. Args: convert (bool): Trigger the conversion. """ if convert: if isinstance(parameters, tuple | list): parameters = tuple(x.tobytes() if isinstance(x, memoryview) else x for x in parameters) elif isinstance(parameters, dict): for k in parameters: if isinstance(parameters[k], memoryview): parameters[k] = parameters[k].tobytes() return statement, parameters
def before_create(table, bind, **kw):
    """Handle spatial indexes during the before_create event.

    Every index that covers a spatial column is taken out of ``table.indexes``.
    Those that must be put back after the table creation are stored in
    ``table.info["_after_create_indexes"]``; an index is NOT stored only when it
    carries the name returned by ``_spatial_idx_name`` for a column whose type has
    ``spatial_index`` enabled. Finally ``table.columns`` is replaced by the value
    popped from ``table.info["_saved_columns"]``.

    Args:
        table: The table about to be created.
        bind: The connection or engine used for the creation.
        **kw: Extra event arguments, unused.
    """
    # Only the dialect is needed here; the call is also expected to populate
    # table.info["_saved_columns"], which is read below (TODO confirm in
    # geoalchemy2.admin.dialects.common.setup_create_drop).
    dialect, _gis_cols, _regular_cols = setup_create_drop(table, bind)

    # Remove the spatial indexes from the table metadata because they should not be
    # created during the table.create() step since the associated columns do not exist
    # at this time.
    table.info["_after_create_indexes"] = []
    # Iterate over a snapshot because table.indexes is modified inside the loop.
    for idx in set(table.indexes):
        for col in table.info["_saved_columns"]:
            if _check_spatial_type(col.type, Geometry, dialect) and col in idx.columns.values():
                # discard() instead of remove(): an index covering several spatial
                # columns matches more than once and remove() would raise KeyError
                # on the second match.
                table.indexes.discard(idx)
                if idx.name != _spatial_idx_name(table.name, col.name) or not getattr(
                    col.type, "spatial_index", False
                ):
                    table.info["_after_create_indexes"].append(idx)

    table.columns = table.info.pop("_saved_columns")
def after_create(table, bind, **kw):
    """Handle spatial indexes during the after_create event.

    A spatial index is created with ``ALTER TABLE ... ADD SPATIAL INDEX`` for every
    Geometry or Geography column whose type has ``spatial_index`` set to ``True``,
    that is not a computed column and that is not already covered by an index
    registered on the table. The indexes stored in
    ``table.info["_after_create_indexes"]`` are then added back to ``table.indexes``.

    Args:
        table: The table that has just been created.
        bind: The connection or engine used for the creation; the statements are
            executed on it.
        **kw: Extra event arguments, unused.

    Raises:
        KeyError: If ``table.info`` has no ``"_after_create_indexes"`` entry.
    """
    dialect = bind.dialect
    # Render identifiers through the dialect so that the schema is included and
    # names needing quotes (reserved words, special characters) are quoted.
    preparer = dialect.identifier_preparer

    for col in table.columns:
        # Add spatial indices for the Geometry and Geography columns
        if (
            _check_spatial_type(col.type, (Geometry, Geography), dialect)
            and col.type.spatial_index is True
            and col.computed is None
            and not any(col in idx.columns.values() for idx in table.indexes)
        ):
            table_name = preparer.format_table(table)
            column_name = preparer.format_column(col)
            sql = f"ALTER TABLE {table_name} ADD SPATIAL INDEX({column_name});"
            bind.execute(text(sql))

    for idx in table.info.pop("_after_create_indexes"):
        table.indexes.add(idx)
def before_drop(table, bind, **kw): return def after_drop(table, bind, **kw): return _MYSQL_FUNCTIONS = {"ST_AsEWKB": "ST_AsBinary", "ST_SetSRID": "ST_SRID"} def _compiles_mysql(cls, fn): def _compile_mysql(element, compiler, **kw): return f"{fn}({compiler.process(element.clauses, **kw)})" compiles(getattr(functions, cls), "mysql")(_compile_mysql)
def register_mysql_mapping(mapping):
    """Register compilation mappings for the given functions.

    Each ``(name, mysql_name)`` pair of the mapping is handed to
    ``_compiles_mysql``.

    Args:
        mapping: Should have the following form::

                {
                    "function_name_1": "mysql_function_name_1",
                    "function_name_2": "mysql_function_name_2",
                    ...
                }
    """
    for pair in mapping.items():
        _compiles_mysql(*pair)
register_mysql_mapping(_MYSQL_FUNCTIONS)


def _compile_GeomFromText_MySql(element, compiler, **kw):
    """Render ``element`` as a ``ST_GeomFromText`` call.

    The SRID of the element type is appended as a last argument only when it is
    strictly positive.
    """
    rendered_args = compiler.process(element.clauses, **kw)
    srid = element.type.srid
    srid_suffix = f", {srid}" if srid > 0 else ""
    return f"ST_GeomFromText({rendered_args}{srid_suffix})"


def _compile_GeomFromWKB_MySql(element, compiler, **kw):
    """Render ``element`` as a call to ``element.identifier`` taking a WKB argument.

    The SRID is the ``value`` of the second clause when available, else the SRID of
    the element type; it is appended only when strictly positive. With
    ``literal_binds`` the WKB clause goes through ``compile_bin_literal`` and the
    rendered value is wrapped in ``unhex(...)``.
    """
    args = list(element.clauses)

    # Store the SRID
    try:
        srid = args[1].value
    except (IndexError, TypeError, ValueError):
        srid = element.type.srid

    wkb_arg = args[0]
    if kw.get("literal_binds", False):
        wkb_sql = f"unhex({compiler.process(compile_bin_literal(wkb_arg), **kw)})"
    else:
        wkb_sql = compiler.process(wkb_arg, **kw)

    srid_suffix = f", {srid}" if srid > 0 else ""
    return f"{element.identifier}({wkb_sql}{srid_suffix})"


@compiles(functions.ST_GeomFromText, "mysql")  # type: ignore
def _MySQL_ST_GeomFromText(element, compiler, **kw):
    """Compile ``ST_GeomFromText`` for the "mysql" dialect."""
    return _compile_GeomFromText_MySql(element, compiler, **kw)


@compiles(functions.ST_GeomFromEWKT, "mysql")  # type: ignore
def _MySQL_ST_GeomFromEWKT(element, compiler, **kw):
    """Compile ``ST_GeomFromEWKT`` for the "mysql" dialect, as ``ST_GeomFromText``."""
    return _compile_GeomFromText_MySql(element, compiler, **kw)


@compiles(functions.ST_GeomFromWKB, "mysql")  # type: ignore
def _MySQL_ST_GeomFromWKB(element, compiler, **kw):
    """Compile ``ST_GeomFromWKB`` for the "mysql" dialect."""
    return _compile_GeomFromWKB_MySql(element, compiler, **kw)


@compiles(functions.ST_GeomFromEWKB, "mysql")  # type: ignore
def _MySQL_ST_GeomFromEWKB(element, compiler, **kw):
    """Compile ``ST_GeomFromEWKB`` for the "mysql" dialect."""
    return _compile_GeomFromWKB_MySql(element, compiler, **kw)