#    Copyright 2013 IBM Corp.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

from oslo_db import exception as db_exc
from oslo_db.sqlalchemy import utils as sqlalchemyutils
from oslo_log import log as logging
from oslo_utils import versionutils

from nova.db.api import api as api_db_api
from nova.db.api import models as api_models
from nova import exception
from nova import objects
from nova.objects import base
from nova.objects import fields

KEYPAIR_TYPE_SSH = 'ssh'
KEYPAIR_TYPE_X509 = 'x509'

LOG = logging.getLogger(__name__)


@api_db_api.context_manager.reader
def _get_from_db(context, user_id, name=None, limit=None, marker=None):
    """Look up keypair rows owned by ``user_id``.

    When ``name`` is given, the single matching row is returned and
    ``limit``/``marker`` are not consulted. Otherwise a list of the user's
    rows is returned, paginated with ``name`` passed as the sort key.

    :param context: request context providing the database ``session``
    :param user_id: owner whose keypairs are queried
    :param name: optional keypair name for a single-row lookup
    :param limit: optional page size handed to ``paginate_query``
    :param marker: optional name of the keypair the page starts after
    :returns: one database row when ``name`` is given, else a list of rows
    :raises KeypairNotFound: ``name`` was given and no row matched
    :raises MarkerNotFound: ``marker`` was given and the user has no
        keypair with that name
    """
    model = api_models.KeyPair
    user_keypairs = context.session.query(model).filter(
        model.user_id == user_id)

    # Single-row lookup: short-circuit before any pagination handling.
    if name is not None:
        found = user_keypairs.filter(model.name == name).first()
        if not found:
            raise exception.KeypairNotFound(user_id=user_id, name=name)
        return found

    marker_row = None
    if marker is not None:
        # The marker is a keypair name; resolve it to that user's row.
        marker_row = (
            context.session.query(model)
            .filter(model.name == marker)
            .filter(model.user_id == user_id)
            .first()
        )
        if not marker_row:
            raise exception.MarkerNotFound(marker=marker)

    paginated = sqlalchemyutils.paginate_query(
        user_keypairs, model, limit, ['name'], marker=marker_row)
    return paginated.all()


@api_db_api.context_manager.reader
def _get_count_from_db(context, user_id):
    """Return the number of keypair rows owned by ``user_id``."""
    model = api_models.KeyPair
    owned = context.session.query(model).filter(model.user_id == user_id)
    return owned.count()


@api_db_api.context_manager.writer
def _create_in_db(context, values):
    """Insert a new keypair row built from ``values`` and return it.

    :param context: request context providing the database ``session``
    :param values: mapping of column values; ``values['name']`` is used in
        the error raised on a duplicate
    :raises KeyPairExists: the save hit a duplicate-entry error
    """
    db_keypair = api_models.KeyPair()
    db_keypair.update(values)
    try:
        db_keypair.save(context.session)
    except db_exc.DBDuplicateEntry:
        raise exception.KeyPairExists(key_name=values['name'])
    else:
        return db_keypair


@api_db_api.context_manager.writer
def _destroy_in_db(context, user_id, name):
    """Delete the keypair row matching ``user_id`` and ``name``.

    :raises KeypairNotFound: the delete affected no rows
    """
    rows_deleted = (
        context.session.query(api_models.KeyPair)
        .filter_by(user_id=user_id)
        .filter_by(name=name)
        .delete()
    )
    if not rows_deleted:
        raise exception.KeypairNotFound(user_id=user_id, name=name)


# TODO(berrange): Remove NovaObjectDictCompat
@base.NovaObjectRegistry.register
class KeyPair(base.NovaPersistentObject, base.NovaObject,
              base.NovaObjectDictCompat):
    """Versioned object wrapping a single keypair database row."""

    # Version 1.0: Initial version
    # Version 1.1: String attributes updated to support unicode
    # Version 1.2: Added keypair type
    # Version 1.3: Name field is non-null
    # Version 1.4: Add localonly flag to get_by_name()
    VERSION = '1.4'

    fields = {
        'id': fields.IntegerField(),
        'name': fields.StringField(nullable=False),
        'user_id': fields.StringField(nullable=True),
        'fingerprint': fields.StringField(nullable=True),
        'public_key': fields.StringField(nullable=True),
        'type': fields.StringField(nullable=False),
    }

    def obj_make_compatible(self, primitive, target_version):
        """Strip the ``type`` key from ``primitive`` for targets below 1.2."""
        super(KeyPair, self).obj_make_compatible(primitive, target_version)
        version_tuple = versionutils.convert_version_to_tuple(target_version)
        if version_tuple >= (1, 2):
            return
        if 'type' in primitive:
            del primitive['type']

    @staticmethod
    def _from_db_object(context, keypair, db_keypair):
        """Copy every declared field from ``db_keypair`` onto ``keypair``.

        Fields named in the local defaults mapping fall back to their
        default when the database row lacks that attribute. The object's
        context is set and its change tracking is reset before returning.

        :returns: the populated ``keypair`` object
        """
        defaults = {'deleted': False, 'deleted_at': None}
        for field_name in keypair.fields:
            use_default = (
                field_name in defaults and
                not hasattr(db_keypair, field_name)
            )
            keypair[field_name] = (
                defaults[field_name] if use_default
                else db_keypair[field_name]
            )
        keypair._context = context
        keypair.obj_reset_changes()
        return keypair

    @staticmethod
    def _get_from_db(context, user_id, name):
        """Delegate to the module-level ``_get_from_db`` for one keypair."""
        return _get_from_db(context, user_id, name=name)

    @staticmethod
    def _destroy_in_db(context, user_id, name):
        """Delegate to the module-level ``_destroy_in_db``."""
        return _destroy_in_db(context, user_id, name)

    @staticmethod
    def _create_in_db(context, values):
        """Delegate to the module-level ``_create_in_db``."""
        return _create_in_db(context, values)

    # TODO(stephenfin): Remove the 'localonly' parameter in v2.0
    @base.remotable_classmethod
    def get_by_name(cls, context, user_id, name, localonly=False):
        """Return the ``KeyPair`` named ``name`` owned by ``user_id``.

        :raises KeypairNotFound: ``localonly`` is true, or the lookup
            found no matching row
        """
        if localonly:
            # There is no longer a "local" (main) table for keypairs, so this
            # will always return nothing now
            raise exception.KeypairNotFound(user_id=user_id, name=name)

        row = cls._get_from_db(context, user_id, name)
        return cls._from_db_object(context, cls(), row)

    @base.remotable_classmethod
    def destroy_by_name(cls, context, user_id, name):
        """Delete the keypair named ``name`` owned by ``user_id``."""
        cls._destroy_in_db(context, user_id, name)

    @base.remotable
    def create(self):
        """Persist this keypair.

        :raises ObjectActionError: the ``id`` attribute is already set
        """
        already_created = self.obj_attr_is_set('id')
        if already_created:
            raise exception.ObjectActionError(
                action='create', reason='already created',
            )

        self._create()

    def _create(self):
        """Insert the pending changes and reload this object from the row."""
        pending = self.obj_get_changes()
        created_row = self._create_in_db(self._context, pending)
        self._from_db_object(self._context, self, created_row)

    @base.remotable
    def destroy(self):
        """Delete the row identified by this object's user_id and name."""
        self._destroy_in_db(self._context, self.user_id, self.name)


@base.NovaObjectRegistry.register
class KeyPairList(base.ObjectListBase, base.NovaObject):
    """Versioned list object holding ``KeyPair`` objects."""

    # Version 1.0: Initial version
    #              KeyPair <= version 1.1
    # Version 1.1: KeyPair <= version 1.2
    # Version 1.2: KeyPair <= version 1.3
    # Version 1.3: Add new parameters 'limit' and 'marker' to get_by_user()
    VERSION = '1.3'

    fields = {
        'objects': fields.ListOfObjectsField('KeyPair'),
    }

    @staticmethod
    def _get_from_db(context, user_id, limit, marker):
        """Delegate to the module-level ``_get_from_db`` for a listing."""
        return _get_from_db(context, user_id, limit=limit, marker=marker)

    @staticmethod
    def _get_count_from_db(context, user_id):
        """Delegate to the module-level ``_get_count_from_db``."""
        return _get_count_from_db(context, user_id)

    @base.remotable_classmethod
    def get_by_user(cls, context, user_id, limit=None, marker=None):
        """Return a ``KeyPairList`` of the keypairs owned by ``user_id``.

        ``limit`` and ``marker`` are forwarded unchanged to the database
        lookup.
        """
        rows = cls._get_from_db(
            context, user_id, limit=limit, marker=marker)
        keypair_list = cls(context)
        return base.obj_make_list(
            context, keypair_list, objects.KeyPair, rows,
        )

    @base.remotable_classmethod
    def get_count_by_user(cls, context, user_id):
        """Return the number of keypairs owned by ``user_id``."""
        count = cls._get_count_from_db(context, user_id)
        return count