# encoding: utf-8

import logging
import os
import warnings
from psycopg2.extensions import adapt as sqlescape
from psycopg2.extras import NumericRange
from sqlalchemy import (
    Column,
    create_engine,
    ForeignKey,
    Integer,
    Table,
    text,
)
from sqlalchemy.exc import (
    IntegrityError,
    SAWarning,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import (
    relationship,
    Session,
    sessionmaker,
)
from sqlalchemy.orm.exc import (
    NoResultFound,
    MultipleResultsFound,
)
from sqlalchemy.sql import (
    compiler,
    select,
)
from sqlalchemy.sql.expression import (
    literal_column,
    table,
)

Base = declarative_base()

from .constants import (
    DataSourceConstants,
    EditionConstants,
    IdentifierConstants,
    LinkRelations,
    MediaTypes,
)
from .. import classifier
# Configuration is referenced by SessionManager and production_session below.
from ..config import Configuration
from ..util.datetime_helpers import utc_now

def flush(db):
    """Flush the database connection unless it's known to already be flushing."""
    is_flushing = False
    if hasattr(db, '_flushing'):
        # This is a regular database session.
        is_flushing = db._flushing
    elif hasattr(db, 'registry'):
        # This is a flask_scoped_session scoped session.
        is_flushing = db.registry()._flushing
    else:
        logging.error("Unknown database connection type: %r", db)
    if not is_flushing:
        db.flush()

def create(db, model, create_method='', create_method_kwargs=None, **kwargs):
    kwargs.update(create_method_kwargs or {})
    created = getattr(model, create_method, model)(**kwargs)
    db.add(created)
    flush(db)
    return created, True

def get_one(db, model, on_multiple='error', constraint=None, **kwargs):
    """Gets an object from the database based on its attributes.

    :param constraint: A single clause that can be passed into
        `sqlalchemy.Query.filter` to limit the object that is returned.
    :return: object or None
    """
    if 'constraint' in kwargs:
        constraint = kwargs['constraint']
        del kwargs['constraint']

    q = db.query(model).filter_by(**kwargs)
    if constraint is not None:
        q = q.filter(constraint)

    try:
        return q.one()
    except MultipleResultsFound:
        if on_multiple == 'error':
            raise
        elif on_multiple == 'interchangeable':
            # These records are interchangeable so we can use
            # whichever one we want.
            #
            # This may be a sign of a problem somewhere else. A
            # database-level constraint might be useful.
            q = q.limit(1)
            return q.one()
    except NoResultFound:
        return None

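# A minimal usage sketch for get_one(), assuming an open `session`; Edition
# is imported at the bottom of this module, and the title filter is purely
# illustrative.
#
#   edition = get_one(
#       session, Edition,
#       constraint=Edition.title.like('%Moby%'),
#       on_multiple='interchangeable',
#   )
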
def get_one_or_create(db, model, create_method='',
                      create_method_kwargs=None, **kwargs):
    one = get_one(db, model, **kwargs)
    if one:
        return one, False
    else:
        __transaction = db.begin_nested()
        try:
            # These kwargs are supported by get_one() but not by create().
            get_one_keys = ['on_multiple', 'constraint']
            for key in get_one_keys:
                if key in kwargs:
                    del kwargs[key]
            obj = create(db, model, create_method, create_method_kwargs, **kwargs)
            __transaction.commit()
            return obj
        except IntegrityError as e:
            logging.info(
                "INTEGRITY ERROR on %r %r, %r: %r",
                model, create_method_kwargs, kwargs, e)
            __transaction.rollback()
            return db.query(model).filter_by(**kwargs).one(), False

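# A sketch of the get-or-create pattern, assuming an open `session`;
# DataSource is imported at the bottom of this module, and the field values
# are illustrative only. The boolean reports whether the row was created.
#
#   source, is_new = get_one_or_create(
#       session, DataSource, name="Example Source",
#       create_method_kwargs=dict(offers_licenses=True),
#   )
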
def numericrange_to_string(r):
    """Helper method to convert a NumericRange to a human-readable string."""
    if not r:
        return ""
    lower = r.lower
    upper = r.upper
    if upper is None and lower is None:
        return ""
    if lower and upper is None:
        return str(lower)
    if upper and lower is None:
        return str(upper)
    if not r.upper_inc:
        upper -= 1
    if not r.lower_inc:
        lower += 1
    if upper == lower:
        return str(lower)
    return "%s-%s" % (lower, upper)

def numericrange_to_tuple(r):
    """Helper method to normalize NumericRange into a tuple."""
    if r is None:
        return (None, None)
    lower = r.lower
    upper = r.upper
    if lower and not r.lower_inc:
        lower += 1
    if upper and not r.upper_inc:
        upper -= 1
    return lower, upper

def tuple_to_numericrange(t):
    """Helper method to convert a tuple to an inclusive NumericRange."""
    if not t:
        return None
    return NumericRange(t[0], t[1], '[]')

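# A round-trip sketch of the three NumericRange helpers above. A '[)' range
# excludes its upper bound, so the helpers shift it to the inclusive value:
#
#   r = NumericRange(9, 12, '[)')     # 9 <= x < 12
#   numericrange_to_string(r)         # -> "9-11"
#   numericrange_to_tuple(r)          # -> (9, 11)
#   tuple_to_numericrange((9, 11))    # -> NumericRange(9, 11, '[]')
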
class PresentationCalculationPolicy(object):
    """Which parts of the Work or Edition's presentation are we
    actually looking to update?
    """
    DEFAULT_LEVELS = 3
    DEFAULT_THRESHOLD = 0.5
    DEFAULT_CUTOFF = 1000

    def __init__(self,
                 choose_edition=True,
                 set_edition_metadata=True,
                 classify=True,
                 choose_summary=True,
                 calculate_quality=True,
                 choose_cover=True,
                 regenerate_opds_entries=False,
                 regenerate_marc_record=False,
                 update_search_index=False,
                 verbose=True,
                 equivalent_identifier_levels=DEFAULT_LEVELS,
                 equivalent_identifier_threshold=DEFAULT_THRESHOLD,
                 equivalent_identifier_cutoff=DEFAULT_CUTOFF,
                 ):
        """Constructor.

        :param choose_edition: Should a new presentation edition be
            chosen/created, or should we assume the old one is fine?
        :param set_edition_metadata: Should we set new values for basic
            metadata such as title?
        :param classify: Should we reconsider the Genres under which a
            Work should be filed?
        :param choose_summary: Should we reconsider which of the available
            summaries is the best?
        :param calculate_quality: Should we recalculate the overall quality
            of the Work?
        :param choose_cover: Should we reconsider which of the available
            cover images is the best?
        :param regenerate_opds_entries: Should we recreate the OPDS entries
            for this Work?
        :param regenerate_marc_record: Should we regenerate the MARC record
            for this Work?
        :param update_search_index: Should we reindex this Work's entry in
            the search index?
        :param verbose: Should we print out information about the work we're
            doing?
        :param equivalent_identifier_levels: When determining which
            identifiers refer to this Work (used when gathering
            classifications, cover images, etc.), how many levels of
            equivalency should we go down? E.g. for one level of equivalency
            we will go from a proprietary vendor ID to the equivalent ISBN.
        :param equivalent_identifier_threshold: When determining which
            identifiers refer to this Work, what is the probability threshold
            for 'equivalency'? E.g. a value of 1 means that we will not count
            two identifiers as equivalent unless we are absolutely certain.
        :param equivalent_identifier_cutoff: When determining which
            identifiers refer to this Work, how many Identifiers are enough?
            Gathering _all_ the identifiers that identify an extremely
            popular work can take an extraordinarily long time for very
            little payoff, so it's useful to have a cutoff. The cutoff is
            applied _per level_, so the total maximum number of equivalent
            identifiers is equivalent_identifier_cutoff *
            equivalent_identifier_levels.
        """
        self.choose_edition = choose_edition
        self.set_edition_metadata = set_edition_metadata
        self.classify = classify
        self.choose_summary = choose_summary
        self.calculate_quality = calculate_quality
        self.choose_cover = choose_cover

        # We will regenerate OPDS entries if any of the metadata
        # changes, but if regenerate_opds_entries is True we will
        # _always_ do so. This is so we can regenerate _all_ the OPDS
        # entries if the OPDS presentation algorithm changes.
        # The same is true for the MARC records, except that they will
        # never be generated unless a MARC organization code is set
        # in a sitewide configuration setting.
        self.regenerate_opds_entries = regenerate_opds_entries
        self.regenerate_marc_record = regenerate_marc_record

        # Similarly for update_search_index.
        self.update_search_index = update_search_index

        self.verbose = verbose

        self.equivalent_identifier_levels = equivalent_identifier_levels
        self.equivalent_identifier_threshold = equivalent_identifier_threshold
        self.equivalent_identifier_cutoff = equivalent_identifier_cutoff

    @classmethod
    def recalculate_everything(cls):
        """A PresentationCalculationPolicy that always recalculates
        everything, even when it doesn't seem necessary.
        """
        return PresentationCalculationPolicy(
            regenerate_opds_entries=True,
            regenerate_marc_record=True,
            update_search_index=True,
        )

    @classmethod
    def reset_cover(cls):
        """A PresentationCalculationPolicy that only resets covers
        (including updating cached entries, if necessary) without
        impacting any other metadata.
        """
        return cls(
            choose_cover=True,
            choose_edition=False,
            set_edition_metadata=False,
            classify=False,
            choose_summary=False,
            calculate_quality=False,
        )

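# A sketch of a narrow policy, using the constructor defaults above:
# reconsider only classification, leaving edition choice, metadata,
# summaries, quality, and covers untouched.
#
#   policy = PresentationCalculationPolicy(
#       choose_edition=False,
#       set_edition_metadata=False,
#       choose_summary=False,
#       calculate_quality=False,
#       choose_cover=False,
#   )
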
def dump_query(query):
    dialect = query.session.bind.dialect
    statement = query.statement
    comp = compiler.SQLCompiler(dialect, statement)
    enc = dialect.encoding
    params = {}
    for k, v in list(comp.params.items()):
        if isinstance(v, str):
            v = v.encode(enc)
        params[k] = sqlescape(v)
    return (comp.string.encode(enc) % params).decode(enc)

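# A debugging sketch, assuming an open `session`: dump_query() interpolates
# the bound parameters so the resulting SQL can be pasted into psql as-is.
#
#   query = session.query(Edition).filter(Edition.title == 'Moby Dick')
#   print(dump_query(query))
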
DEBUG = False

class SessionManager(object):

    # A function that calculates recursively equivalent identifiers
    # is also defined in SQL.
    RECURSIVE_EQUIVALENTS_FUNCTION = 'recursive_equivalents.sql'

    engine_for_url = {}

    @classmethod
    def engine(cls, url=None):
        url = url or Configuration.database_url()
        # SQLAlchemy's default QueuePool allows pool_size=5 plus
        # max_overflow=10 connections; these settings allow many more.
        return create_engine(url, echo=DEBUG, pool_size=30, max_overflow=60,
                             pool_recycle=60*30)

    @classmethod
    def sessionmaker(cls, url=None, session=None):
        if not (url or session):
            url = Configuration.database_url()
        if url:
            bind_obj = cls.engine(url)
        elif session:
            bind_obj = session.get_bind()
            if not os.environ.get('TESTING'):
                # If a factory is being created from a session in test mode,
                # use the same Connection for all of the tests so objects can
                # be accessed. Otherwise, bind against an Engine object.
                bind_obj = bind_obj.engine
        return sessionmaker(bind=bind_obj)

    @classmethod
    def resource_directory(cls):
        """The directory containing SQL files used in database setup."""
        base_path = os.path.split(__file__)[0]
        return os.path.join(base_path, "files")

    @classmethod
    def initialize(cls, url, initialize_data=True, initialize_schema=True):
        """Initialize the database.

        This includes the schema, the custom functions, and the
        initial content.
        """
        if url in cls.engine_for_url:
            engine = cls.engine_for_url[url]
            return engine

        engine = cls.engine(url)
        if initialize_schema:
            cls.initialize_schema(engine)
        connection = engine.connect()

        # Check if the recursive equivalents function exists already.
        query = select(
            [literal_column('proname')]
        ).select_from(
            table('pg_proc')
        ).where(
            literal_column('proname') == 'fn_recursive_equivalents'
        )
        result = connection.execute(query)
        result = list(result)

        # If it doesn't, create it.
        if not result and initialize_data:
            resource_file = os.path.join(
                cls.resource_directory(), cls.RECURSIVE_EQUIVALENTS_FUNCTION
            )
            if not os.path.exists(resource_file):
                raise IOError(
                    "Could not load recursive equivalents function from %s: "
                    "file does not exist." % resource_file
                )
            sql = open(resource_file).read()
            connection.execute(sql)

        if initialize_data:
            session = Session(connection)
            cls.initialize_data(session)
            session.close()

        if connection:
            connection.close()

        if initialize_schema and initialize_data:
            # Only cache the engine if all initialization has been performed.
            #
            # Some pieces of code (e.g. the script that runs
            # migrations) have a legitimate need to bypass some of the
            # initialization, but normal operation of the site
            # requires that everything be initialized.
            #
            # Until someone tells this method to initialize
            # everything, we can't short-circuit this method with a
            # cache.
            cls.engine_for_url[url] = engine
        return engine

    @classmethod
    def initialize_schema(cls, engine):
        """Initialize the database schema."""
        # Use SQLAlchemy to create all the tables.
        to_create = [
            table_obj for name, table_obj in list(Base.metadata.tables.items())
            if not name.startswith('mv_')
        ]
        Base.metadata.create_all(engine, tables=to_create)

    @classmethod
    def session(cls, url, initialize_data=True, initialize_schema=True):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=SAWarning)
            engine = cls.initialize(
                url,
                initialize_data=initialize_data,
                initialize_schema=initialize_schema,
            )
        return Session(engine)

    @classmethod
    def initialize_data(cls, session, set_site_configuration=True):
        # Create initial content.
        from .datasource import DataSource
        from .classification import Genre
        from .licensing import DeliveryMechanism
        list(DataSource.well_known_sources(session))

        # Load all existing Genre objects.
        Genre.populate_cache(session)

        # Create any genres not in the database.
        for g in list(classifier.genres.values()):
            # TODO: On the very first startup this is rather expensive
            # because the cache is invalidated every time a Genre is
            # created, then populated the next time a Genre is looked
            # up. This wouldn't be a big problem, but this also happens
            # on setup for the unit tests.
            Genre.lookup(session, g, autocreate=True)

        # Make sure that the mechanisms fulfillable by the default
        # client are marked as such.
        for content_type, drm_scheme in DeliveryMechanism.default_client_can_fulfill_lookup:
            mechanism, is_new = DeliveryMechanism.lookup(
                session, content_type, drm_scheme
            )
            mechanism.default_client_can_fulfill = True

        # If there is currently no 'site configuration change'
        # Timestamp in the database, create one.
        timestamp, is_new = get_one_or_create(
            session, Timestamp, collection=None,
            service=Configuration.SITE_CONFIGURATION_CHANGED,
            create_method_kwargs=dict(finish=utc_now())
        )
        if is_new:
            site_configuration_has_changed(session)
        session.commit()

        # Return a potentially-new Session object in case
        # it was updated by cls.update_timestamps_table
        return session

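# A sketch of typical setup, assuming a Postgres URL: session() runs
# initialize() (schema, SQL functions, seed data) on first use and returns
# a working Session bound to the cached engine.
#
#   session = SessionManager.session("postgresql://user:pass@localhost/db")
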
def production_session(initialize_data=True):
    url = Configuration.database_url()
    if url.startswith('"'):
        url = url[1:]
    logging.debug("Database url: %s", url)
    _db = SessionManager.session(url, initialize_data=initialize_data)

    # The first thing to do after getting a database connection is to
    # set up the logging configuration.
    #
    # If called during a unit test, this will configure logging
    # incorrectly, but 1) this method isn't normally called during
    # unit tests, and 2) package_setup() will call initialize() again
    # with the right arguments.
    from ..log import LogConfiguration
    LogConfiguration.initialize(_db)
    return _db

from .admin import (
    Admin,
    AdminRole,
)
from .coverage import (
    BaseCoverageRecord,
    CoverageRecord,
    Timestamp,
    WorkCoverageRecord,
)
from .cachedfeed import (
    CachedFeed,
    WillNotGenerateExpensiveFeed,
    CachedMARCFile,
)
from .circulationevent import CirculationEvent
from .classification import (
    Classification,
    Genre,
    Subject,
)
from .collection import (
    Collection,
    CollectionIdentifier,
    CollectionMissing,
    collections_identifiers,
)
from .configuration import (
    ConfigurationSetting,
    ExternalIntegration,
    ExternalIntegrationLink,
)
from .complaint import Complaint
from .contributor import (
    Contribution,
    Contributor,
)
from .credential import (
    Credential,
    DelegatedPatronIdentifier,
    DRMDeviceIdentifier,
)
from .customlist import (
    CustomList,
    CustomListEntry,
)
from .datasource import DataSource
from .edition import Edition
from .hasfulltablecache import HasFullTableCache
from .identifier import (
    Equivalency,
    Identifier,
)
from .integrationclient import IntegrationClient
from .library import Library
from .licensing import (
    DeliveryMechanism,
    License,
    LicensePool,
    LicensePoolDeliveryMechanism,
    PolicyException,
    RightsStatus,
)
from .measurement import Measurement
from .patron import (
    Annotation,
    Hold,
    Loan,
    LoanAndHoldMixin,
    Patron,
    PatronProfileStorage,
)
from .listeners import *
from .resource import (
    Hyperlink,
    Representation,
    Resource,
    ResourceTransformation,
)
from .work import (
    Work,
    WorkGenre,
)