SQLAlchemy: replacing object with a new one, following defaults

I want to create a new instance of an SQLAlchemy object so that its fields are filled with default values, but commit it to the database as an UPDATE to a row that already exists with the same primary key, effectively resetting that row to the default values. Is there any simple way to do that?

I have tried to do that and failed, because the SQLAlchemy session tracks the state of objects, so there is no easy way to make the session treat a brand-new object as a persistent one.
But you want to reset the object to its defaults, don't you? There is a simple way to do that:
from sqlalchemy.ext.declarative import declarative_base

class Base(object):
    def reset(self):
        for name, column in self.__class__.__table__.columns.items():
            if column.default is not None:
                setattr(self, name, column.default.execute())

Base = declarative_base(bind=engine, cls=Base)
This adds a reset() method to all your model classes.
Here is the complete working example to fiddle with:
import os
from datetime import datetime

from sqlalchemy import create_engine
from sqlalchemy import Column, Integer, String, DateTime
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import functions

here = os.path.abspath(os.path.dirname(__file__))
engine = create_engine('sqlite:///%s/db.sqlite' % here, echo=True)
Session = sessionmaker(bind=engine)

class Base(object):
    def reset(self):
        for name, column in self.__class__.__table__.columns.items():
            if column.default is not None:
                setattr(self, name, column.default.execute())

Base = declarative_base(bind=engine, cls=Base)

class Thing(Base):
    __tablename__ = 'things'
    id = Column(Integer, primary_key=True)
    value = Column(String(255), default='default')
    ts1 = Column(DateTime, default=datetime.now)
    ts2 = Column(DateTime, default=functions.now())

    def __repr__(self):
        return '<Thing(id={0.id!r}, value={0.value!r}, ' \
               'ts1={0.ts1!r}, ts2={0.ts2!r})>'.format(self)

if __name__ == '__main__':
    Base.metadata.drop_all()
    Base.metadata.create_all()

    print("---------------------------------------")
    print("Create a new thing")
    print("---------------------------------------")
    session = Session()
    thing = Thing(
        value='some value',
        ts1=datetime(2014, 1, 1),
        ts2=datetime(2014, 2, 2),
    )
    session.add(thing)
    session.commit()
    session.close()

    print("---------------------------------------")
    print("Querying it from DB")
    print("---------------------------------------")
    session = Session()
    thing = session.query(Thing).filter(Thing.id == 1).one()
    print(thing)
    session.close()

    print("---------------------------------------")
    print("Reset it to default")
    print("---------------------------------------")
    session = Session()
    thing = session.query(Thing).filter(Thing.id == 1).one()
    thing.reset()
    session.commit()
    session.close()

    print("---------------------------------------")
    print("Querying it from DB")
    print("---------------------------------------")
    session = Session()
    thing = session.query(Thing).filter(Thing.id == 1).one()
    print(thing)
    session.close()

Is there any simple way to do that?
Upon further consideration, not really. The cleanest way is to define your defaults in __init__. The constructor is never called when SQLAlchemy fetches objects from the DB, so it's perfectly safe. You can also use backend functions such as current_timestamp().
import sqlalchemy as sa
from sqlalchemy import Column

class MyObject(Base):
    __tablename__ = 'my_objects'  # table name assumed for the example
    id = Column(sa.Integer, primary_key=True)
    column1 = Column(sa.String)
    column2 = Column(sa.Integer)
    columnN = Column(sa.String)
    updated = Column(sa.DateTime)

    def __init__(self, **kwargs):
        kwargs.setdefault('column1', 'default value')
        kwargs.setdefault('column2', 123)
        kwargs.setdefault('columnN', None)
        kwargs.setdefault('updated', sa.func.current_timestamp())
        super(MyObject, self).__init__(**kwargs)

# A fresh instance gets its defaults from __init__; merge() matches it to the
# existing row by primary key, so commit() emits an UPDATE, not an INSERT.
default_obj = MyObject()
default_obj.id = old_id  # old_id is the primary key of the row to reset
session.merge(default_obj)
session.commit()

Related

SQLAlchemy Table classes and imports

I started a project using PostgreSQL and SQLAlchemy. Since I'm not an experienced programmer (I just started using classes) and am also quite new to databases, I noticed some workflows I don't really understand.
What I understand up till now about classes is the following workflow:
# filename.py
class ClassName():
    def __init__(self):
        pass  # do something

    def some_function(self, var1, var2):
        pass  # do something with parameters
---------------------------------------
# main.py
from filename import ClassName

par1 = ...
par2 = ...
a = ClassName()
b = a.some_function(par1, par2)
Now I am creating tables from classes:
# base.py
from sqlalchemy.orm import declarative_base

Base = declarative_base()

# tables.py
from base import Base
from sqlalchemy import Column
from sqlalchemy import Integer, String

class A(Base):
    __tablename__ = "a"
    a_id = Column(Integer, primary_key=True)
    a_column = Column(String(30))

class B(Base):
    __tablename__ = "b"
    b_id = Column(Integer, primary_key=True)
    b_column = Column(String(30))
and
import typing

from base import Base
from sqlalchemy import create_engine
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy.orm import sessionmaker
from tables import A, B

metadata_obj = MetaData()

def create_tables(engine):
    session = sessionmaker()
    session.configure(bind=engine)
    Base.metadata.create_all(bind=engine)
    a = Table("a", metadata_obj, autoload_with=engine)
    b = Table("b", metadata_obj, autoload_with=engine)
    return (a, b)  # not sure return is needed

if __name__ == "__main__":
    username = "username"
    password = "AtPasswordHere!"
    dbname = "dbname"
    url = "postgresql://" + username + ":" + password + "@localhost/" + dbname
    engine = create_engine(url, echo=True, future=True)
    a, b = create_tables(engine)
Everything works fine, in that it creates table a and table b in the database. The points I don't understand are the following:
Both my IDE (pyflakes) and LGTM complain "'tables.A' imported but not used". (EDIT: I understand why it complains; the question is more about why this isn't the normal class workflow.)
Is this normal behavior for this use case? I only see examples that use the above workflow.
Are there better methods that produce the same result (but without the warnings)?
If this is the normal behavior, is there an explanation for it? I haven't read about it anywhere.
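A short sketch of what is going on (assuming the files above): defining a declarative class registers its table on Base.metadata as a side effect of importing the module, so the import is needed for create_all even though the names A and B are never referenced directly, and the warning can be silenced explicitly:

# main.py
import tables  # noqa: F401  - imported for the side effect of registering A and B on Base.metadata
from base import Base

Base.metadata.create_all(bind=engine)  # now knows about tables "a" and "b"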

How to use nested pydantic models for sqlalchemy in a flexible way

from fastapi import Depends, FastAPI, HTTPException, Body, Request
from sqlalchemy import create_engine, Boolean, Column, ForeignKey, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker, relationship
from sqlalchemy.inspection import inspect
from typing import List, Optional
from pydantic import BaseModel
import json
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
app = FastAPI()
# sqlalchemy models
class RootModel(Base):
    __tablename__ = "root_table"
    id = Column(Integer, primary_key=True, index=True)
    someRootText = Column(String)
    subData = relationship("SubModel", back_populates="rootData")

class SubModel(Base):
    __tablename__ = "sub_table"
    id = Column(Integer, primary_key=True, index=True)
    someSubText = Column(String)
    root_id = Column(Integer, ForeignKey("root_table.id"))
    rootData = relationship("RootModel", back_populates="subData")

# pydantic models/schemas
class SchemaSubBase(BaseModel):
    someSubText: str

    class Config:
        orm_mode = True

class SchemaSub(SchemaSubBase):
    id: int
    root_id: int

    class Config:
        orm_mode = True

class SchemaRootBase(BaseModel):
    someRootText: str
    subData: List[SchemaSubBase] = []

    class Config:
        orm_mode = True

class SchemaRoot(SchemaRootBase):
    id: int

    class Config:
        orm_mode = True

class SchemaSimpleBase(BaseModel):
    someRootText: str

    class Config:
        orm_mode = True

class SchemaSimple(SchemaSimpleBase):
    id: int

    class Config:
        orm_mode = True
Base.metadata.create_all(bind=engine)
# database functions (CRUD)
def db_add_simple_data_pydantic(db: Session, root: SchemaRootBase):
    db_root = RootModel(**root.dict())
    db.add(db_root)
    db.commit()
    db.refresh(db_root)
    return db_root

def db_add_nested_data_pydantic_generic(db: Session, root: SchemaRootBase):
    # this fails:
    db_root = RootModel(**root.dict())
    db.add(db_root)
    db.commit()
    db.refresh(db_root)
    return db_root

def db_add_nested_data_pydantic(db: Session, root: SchemaRootBase):
    # start: hack: I have to manually generate the sqlalchemy model from the pydantic model
    root_dict = root.dict()
    sub_dicts = []
    # I have to remove the list from the root dict in order to fix the error from above
    for key in list(root_dict):
        if isinstance(root_dict[key], list):
            sub_dicts = root_dict[key]
            del root_dict[key]
    # now I can do it
    db_root = RootModel(**root_dict)
    for sub_dict in sub_dicts:
        db_root.subData.append(SubModel(**sub_dict))
    # end: hack
    db.add(db_root)
    db.commit()
    db.refresh(db_root)
    return db_root

def db_add_nested_data_nopydantic(db: Session, root):
    print(root)
    sub_dicts = root.pop("subData")
    print(sub_dicts)
    db_root = RootModel(**root)
    for sub_dict in sub_dicts:
        db_root.subData.append(SubModel(**sub_dict))
    db.add(db_root)
    db.commit()
    db.refresh(db_root)
    # problem
    """
    If I would now "return db_root", the answer would be this:
    {
        "someRootText": "string",
        "id": 24
    }
    not containing "subData"; therefore I have to do the following. Why?
    """
    from sqlalchemy.orm import joinedload
    db_root = (
        db.query(RootModel)
        .options(joinedload(RootModel.subData))
        .filter(RootModel.id == db_root.id)
        .all()
    )[0]
    return db_root
# Dependency
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

@app.post("/addNestedModel_pydantic_generic", response_model=SchemaRootBase)
def add_nested_data_pydantic_generic(root: SchemaRootBase, db: Session = Depends(get_db)):
    data = db_add_nested_data_pydantic_generic(db=db, root=root)
    return data

@app.post("/addSimpleModel_pydantic", response_model=SchemaSimpleBase)
def add_simple_data_pydantic(root: SchemaSimpleBase, db: Session = Depends(get_db)):
    data = db_add_simple_data_pydantic(db=db, root=root)
    return data

@app.post("/addNestedModel_nopydantic")
def add_nested_data_nopydantic(root=Body(...), db: Session = Depends(get_db)):
    data = db_add_nested_data_nopydantic(db=db, root=root)
    return data

@app.post("/addNestedModel_pydantic", response_model=SchemaRootBase)
def add_nested_data_pydantic(root: SchemaRootBase, db: Session = Depends(get_db)):
    data = db_add_nested_data_pydantic(db=db, root=root)
    return data
Description
My question is:
How to make nested sqlalchemy models from nested pydantic models (or python dicts) in a generic way, and write them to the database in "one shot"?
My example model is called RootModel and has a list of submodels in the subData key.
Please see above for the pydantic and sqlalchemy definitions.
Example:
The user provides a nested json string:
{
    "someRootText": "string",
    "subData": [
        {
            "someSubText": "string"
        }
    ]
}
Open the browser and call the endpoint /docs.
You can play around with all endpoints and POST the json string from above.
/addNestedModel_pydantic_generic
When you call the endpoint /addNestedModel_pydantic_generic it fails, because sqlalchemy cannot create the nested model from the nested pydantic model directly:
AttributeError: 'dict' object has no attribute '_sa_instance_state'
/addSimpleModel_pydantic
With a non-nested model it works.
The remaining endpoints show "hacks" to solve the problem of nested models.
/addNestedModel_pydantic
In this endpoint I generate the root model and add the submodels in a loop, non-generically, using pydantic models.
/addNestedModel_nopydantic
In this endpoint I generate the root model and add the submodels in a loop, non-generically, using python dicts.
My solutions are only hacks; I want a generic way to create nested sqlalchemy models either from pydantic (preferred) or from a python dict.
Environment
OS: Windows
FastAPI version: 0.61.1
Python version: 3.8.5
sqlalchemy: 1.3.19
pydantic: 1.6.1
I haven't found a nice built-in way to do this within pydantic/SQLAlchemy. How I solved it: I gave every nested pydantic model a Meta class containing the corresponding SQLAlchemy model. Like so:
from pydantic import BaseModel
from models import ChildDBModel, ParentDBModel

class ChildModel(BaseModel):
    some_attribute: str = 'value'

    class Meta:
        orm_model = ChildDBModel

class ParentModel(BaseModel):
    child: ChildModel
That allowed me to write a generic function that loops through the pydantic object and transforms submodels into SQLAlchemy models:
def is_pydantic(obj: object):
    """Checks whether an object is pydantic."""
    return type(obj).__class__.__name__ == "ModelMetaclass"

def parse_pydantic_schema(schema):
    """
    Iterates through pydantic schema and parses nested schemas
    to a dictionary containing SQLAlchemy models.
    Only works if nested schemas have specified the Meta.orm_model.
    """
    parsed_schema = dict(schema)
    for key, value in parsed_schema.items():
        try:
            if isinstance(value, list) and len(value):
                if is_pydantic(value[0]):
                    parsed_schema[key] = [item.Meta.orm_model(**item.dict()) for item in value]
            else:
                if is_pydantic(value):
                    parsed_schema[key] = value.Meta.orm_model(**value.dict())
        except AttributeError:
            raise AttributeError("Found nested Pydantic model but Meta.orm_model was not specified.")
    return parsed_schema
The parse_pydantic_schema function returns a dictionary representation of the pydantic model where submodels are substituted by the corresponding SQLAlchemy model specified in Meta.orm_model. You can use this return value to create the parent SQLAlchemy model in one go:
parsed_schema = parse_pydantic_schema(parent_model)  # parent_model is an instance of pydantic ParentModel
new_db_model = ParentDBModel(**parsed_schema)
# do your db actions/commit here
If you want you can even extend this to also automatically create the parent model, but that requires you to also specify the Meta.orm_model for all pydantic models.
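For instance, a minimal sketch of that extension (my own addition, assuming the top-level pydantic model also defines Meta.orm_model):

def create_orm_model(schema):
    # Works for the parent too, since its Meta.orm_model is now specified.
    return schema.Meta.orm_model(**parse_pydantic_schema(schema))

new_db_model = create_orm_model(parent_model)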
Using validators is a lot simpler:
SQLAlchemy models.py:
from sqlalchemy.orm import relationship, validates

class ChildModel(Base):
    __tablename__ = "Child"
    name: str = Column(Unicode(255), nullable=False, primary_key=True)

class ParentModel(Base):
    __tablename__ = "Parent"
    some_attribute: str = Column(Unicode(255))
    children = relationship("ChildModel", lazy="joined", cascade="all, delete-orphan")

    @validates("children")
    def adjust_children(self, _, value) -> ChildModel:
        """Instantiate a Child object if it is only a plain string."""
        if value and isinstance(value, str):
            return ChildModel(name=value)
        return value
Pydantic schema.py:
class Parent(BaseModel):
    """Model used for parents."""
    some_attribute: str
    children: List[str] = Field(example=["foo", "bar"], default=[])

    @validator("children", pre=True)
    def adjust_children(cls, children):
        """Convert to plain strings if they are Child objects."""
        if children and not isinstance(next(iter(children), None), str):
            return [child["name"] for child in children]
        return children
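A quick usage sketch (my own example values, assuming a configured session; the @validates hook fires for each item assigned to the collection, so the plain strings become Child rows):

parent = ParentModel(some_attribute="something", children=["foo", "bar"])  # strings converted by adjust_children
session.add(parent)
session.commit()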
Nice function @dann! For more than two levels of nesting you can use this recursive function:
def pydantic_to_sqlalchemy_model(schema):
    """
    Iterates through pydantic schema and parses nested schemas
    to a dictionary containing SQLAlchemy models.
    Only works if nested schemas have specified the Meta.orm_model.
    """
    parsed_schema = dict(schema)
    for key, value in parsed_schema.items():
        try:
            if isinstance(value, list) and len(value) and is_pydantic(value[0]):
                parsed_schema[key] = [
                    item.Meta.orm_model(**pydantic_to_sqlalchemy_model(item))
                    for item in value
                ]
            elif is_pydantic(value):
                parsed_schema[key] = value.Meta.orm_model(
                    **pydantic_to_sqlalchemy_model(value)
                )
        except AttributeError:
            raise AttributeError(
                f"Found nested Pydantic model in {schema.__class__} but Meta.orm_model was not specified."
            )
    return parsed_schema
Use it sparingly! If you have cyclic nesting it will loop forever.
And then call your data transformer like this:
def create_parent(db: Session, parent: Parent_pydantic_schema):
    db_parent = Parent_model(**pydantic_to_sqlalchemy_model(parent))
    db.add(db_parent)
    db.commit()
    return db_parent

How do I write SQLAlchemy test fixtures for FastAPI applications

I am writing a FastAPI application that uses a SQLAlchemy database. I have copied the example from the FastAPI documentation, simplifying the database schema for concision's sake. The complete source is at the bottom of this post.
This works. I can run it with uvicorn sql_app.main:app and interact with the database via the Swagger docs. When it runs it creates a test.db in the working directory.
Now I want to add a unit test. Something like this.
from fastapi import status
from fastapi.testclient import TestClient
from pytest import fixture

from main import app

@fixture
def client() -> TestClient:
    return TestClient(app)

def test_fast_sql(client: TestClient):
    response = client.get("/users/")
    assert response.status_code == status.HTTP_200_OK
    assert response.json() == []
Using the source code below, this takes the test.db in the working directory as the database. Instead I want to create a new database for every unit test that is deleted at the end of the test.
I could put the global database.engine and database.SessionLocal inside an object that is created at runtime, like so:
class UserDatabase:
    def __init__(self, directory: Path):
        directory.mkdir(exist_ok=True, parents=True)
        sqlalchemy_database_url = f"sqlite:///{directory}/store.db"
        self.engine = create_engine(
            sqlalchemy_database_url, connect_args={"check_same_thread": False}
        )
        self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
        models.Base.metadata.create_all(bind=self.engine)
but I don't know how to make that work with main.get_db, since the Depends(get_db) logic ultimately assumes database.engine and database.SessionLocal are available globally.
I'm used to working with Flask, whose unit testing facilities handle all this for you. I don't know how to write it myself. Can someone show me the minimal changes I'd have to make in order to generate a new database for each unit test in this framework?
The complete source of the simplified FastAPI/SQLAlchemy app is as follows.
database.py
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
models.py
from sqlalchemy import Column, Integer, String

from database import Base

class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String)
    age = Column(Integer)
schemas.py
from pydantic import BaseModel

class UserBase(BaseModel):
    name: str
    age: int

class UserCreate(UserBase):
    pass

class User(UserBase):
    id: int

    class Config:
        orm_mode = True
crud.py
from sqlalchemy.orm import Session

import schemas
import models

def get_user(db: Session, user_id: int):
    return db.query(models.User).filter(models.User.id == user_id).first()

def get_users(db: Session, skip: int = 0, limit: int = 100):
    return db.query(models.User).offset(skip).limit(limit).all()

def create_user(db: Session, user: schemas.UserCreate):
    db_user = models.User(name=user.name, age=user.age)
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user
main.py
from typing import List

from fastapi import Depends, FastAPI, HTTPException
from sqlalchemy.orm import Session

import schemas
import models
import crud
from database import SessionLocal, engine

models.Base.metadata.create_all(bind=engine)

app = FastAPI()

# Dependency
def get_db():
    try:
        db = SessionLocal()
        yield db
    finally:
        db.close()

@app.post("/users/", response_model=schemas.User)
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db)):
    return crud.create_user(db=db, user=user)

@app.get("/users/", response_model=List[schemas.User])
def read_users(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    users = crud.get_users(db, skip=skip, limit=limit)
    return users

@app.get("/users/{user_id}", response_model=schemas.User)
def read_user(user_id: int, db: Session = Depends(get_db)):
    db_user = crud.get_user(db, user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user
You need to override your get_db dependency in your tests, see these docs.
Something like this for your fixture:
@fixture
def db_fixture() -> Session:
    raise NotImplementedError()  # Make this return your temporary session

@fixture
def client(db_fixture) -> TestClient:
    def _get_db_override():
        return db_fixture
    app.dependency_overrides[get_db] = _get_db_override
    return TestClient(app)
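For example, a minimal sketch of what db_fixture could return, assuming a throwaway in-memory SQLite database per test (module names follow the question's code):

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

import models

@fixture
def db_fixture() -> Session:
    # Fresh in-memory database for every test; it disappears with the engine.
    engine = create_engine("sqlite://", connect_args={"check_same_thread": False})
    models.Base.metadata.create_all(bind=engine)
    TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()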

How to make the id auto-increment by 2 in a Django model?

Since the auto_increment increment setting in MySQL is global, it cannot be set for a specific table.
So I'm considering whether it's possible to make the id auto-increment by 2 in a Django model.
models.py
class Video(models.Model):
    name = models.CharField(max_length=100, default='')
    upload_time = models.DateTimeField(blank=True, null=True)

    def __str__(self):
        return self.name
What should I do? Thanks for your help.
You could do it by overriding the save() method of your model:
from django.db.models import Max, F

class Video(models.Model):
    id = models.AutoField(primary_key=True)
    name = models.CharField(max_length=100, default='')
    upload_time = models.DateTimeField(blank=True, null=True)

    def save(self, *args, **kwargs):
        if not self.pk:
            max = Video.objects.aggregate(max=Max(F('id')))['max']
            self.id = max + 2 if max else 1  # if the DB is empty
        super().save(*args, **kwargs)

    def __str__(self):
        return self.name
The correct way is to change your MySQL server settings.
Check this out: auto_increment_increment
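For illustration, a sketch of setting it per connection from Django (my own example, not from the answer; the persistent, global setting belongs in the MySQL server configuration):

from django.db import connection

# Applies only to the current connection/session; put
# auto_increment_increment in the server config to make it global.
with connection.cursor() as cursor:
    cursor.execute("SET SESSION auto_increment_increment = 2")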
Possible solutions:
Assume I have a Customer model. One option is to compute the next id from the last row:
primary_key_id = models.IntegerField(default=(Customer.objects.order_by('primary_key_id').last().primary_key_id + 2), primary_key=True)  # note: evaluated once, at class-definition time
or
from django.db import transaction

# Uncomment the commented lines for Django versions less than 2.0
def save(self):
    """Get the last id from the database and increment it before saving."""
    # with transaction.atomic():
    #     top = Customer.objects.select_for_update(nowait=True).order_by('-customer_customerid')[0]  # Ensures a lock on the database
    top = Customer.objects.order_by('-id')[0]
    self.id = top.id + 2
    super(Customer, self).save()
The above code would not have a concurrency issue for Django 2.0, as:
As of Django 2.0, related rows are locked by default (not sure what the behaviour was before) and the rows to lock can be specified in the same style as select_related, using the of parameter!
For lower versions, you need to be atomic!
or
from django.db import transaction

def increment():
    with transaction.atomic():
        ids = Customer.objects.all()
        length = len(ids) - 1
        if length <= 0:  # Changed to handle empty entries
            return 1
        else:
            id = ids[length].customer_customerid
            return id + 2
or
from django.db import transaction

def increment():
    with transaction.atomic():
        return Customer.objects.select_for_update(nowait=True).order_by('-customer_customerid')[0]  # Ensures atomic approach!
Then set the primary key in the model to an IntegerField and let increment() supply the value for every new entry. In your models.py, set the primary key to:
# import increment from the module where you defined it
primary_key_id = models.IntegerField(default=increment, primary_key=True)  # pass the callable, so it runs on every insert rather than once at import

SQLAlchemy session woes in unittest

I've just started using SQLAlchemy a few days ago and right now I'm stuck with a problem that I hope anyone can shed some light on before I lose all my hair.
When I run a unittest, see the snippet below, only the first test in the sequence passes. The test testPrintItem works just fine, but testDigitalItem fails with a NoResultFound exception - No row was found for one(). But if I remove testPrintItem from the test class, then testDigitalItem works.
I assume that the problem has something to do with the session, but I can't really get a grip on it.
In case anyone wonders, the setup is as follows:
Python 3.1 (Ubuntu 10.04 package)
SQLAlchemy 0.7.2 (easy_install:ed)
PostgreSQL 8.4.8 (Ubuntu 10.04 package)
PsycoPG2 2.4.2 (easy_installed:ed)
Example test:
class TestSchema(unittest.TestCase):
    test_items = [
        # Some parent class products
        PrintItem(key='p1', title='Possession', dimension='30x24'),
        PrintItem(key='p2', title='Andrzej Żuławski - a director', dimension='22x14'),
        DigitalItem(key='d1', title='Every Man His Own University', url='http://www.gutenberg.org/files/36955/36955-h/36955-h.htm'),
        DigitalItem(key='d2', title='City Ballads', url='http://www.gutenberg.org/files/36954/36954-h/36954-h.htm'),
    ]

    def testPrintItem(self):
        item = self.session.query(PrintItem).filter(PrintItem.key == 'p1').one()
        assert item.title == 'Possession', 'Title mismatch'

    def testDigitalItem(self):
        item2 = self.session.query(DigitalItem).filter(DigitalItem.key == 'd2').one()
        assert item2.title == 'City Ballads', 'Title mismatch'

    def setUp(self):
        Base.metadata.create_all()
        self.session = DBSession()
        self.session.add_all(self.test_items)
        self.session.commit()

    def tearDown(self):
        self.session.close()
        Base.metadata.drop_all()

if __name__ == '__main__':
    unittest.main()
UPDATE
Here is the working code snippet.
# -*- coding: utf-8 -*-
import time
import unittest

from sqlalchemy import *
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import *

Base = declarative_base()
engine = create_engine('sqlite:///testdb', echo=False)
DBSession = sessionmaker(bind=engine)

class ItemMixin(object):
    """
    Common attributes for items, i.e. books, DVDs...
    """
    __tablename__ = 'testitems'
    __table_args__ = {'extend_existing': True}
    id = Column(Integer, autoincrement=True, primary_key=True)
    key = Column(Unicode(16), unique=True, nullable=False)
    title = Column(UnicodeText, default=None)
    item_type = Column(Unicode(20), default=None)
    __mapper_args__ = {'polymorphic_on': item_type}

    def __init__(self, key, title=None):
        self.key = key
        self.title = title

class FooItem(Base, ItemMixin):
    foo = Column(UnicodeText, default=None)
    __mapper_args__ = {'polymorphic_identity': 'foo'}

    def __init__(self, foo=None, **kwargs):
        ItemMixin.__init__(self, **kwargs)
        self.foo = foo

class BarItem(Base, ItemMixin):
    bar = Column(UnicodeText, default=None)
    __mapper_args__ = {'polymorphic_identity': 'bar'}

    def __init__(self, bar=None, **kwargs):
        ItemMixin.__init__(self, **kwargs)
        self.bar = bar

# Tests
class TestSchema(unittest.TestCase):
    # Class variables
    is_setup = False
    session = None
    metadata = None
    test_items = [
        FooItem(key='f1', title='Possession', foo='Hello'),
        FooItem(key='f2', title='Andrzej Żuławsk', foo='World'),
        BarItem(key='b1', title='Wikipedia', bar='World'),
        BarItem(key='b2', title='City Ballads', bar='Hello'),
    ]

    def testFooItem(self):
        print('Test Foo Item')
        item = self.__class__.session.query(FooItem).filter(FooItem.key == 'f1').first()
        assert item.title == 'Possession', 'Title mismatch'

    def testBarItem(self):
        print('Test Bar Item')
        item = self.__class__.session.query(BarItem).filter(BarItem.key == 'b2').first()
        assert item.title == 'City Ballads', 'Title mismatch'

    def setUp(self):
        if not self.__class__.is_setup:
            self.__class__.session = DBSession()
            self.metadata = Base.metadata
            self.metadata.bind = engine
            self.metadata.drop_all()    # Drop tables
            self.metadata.create_all()  # Create tables
            self.__class__.session.add_all(self.test_items)  # Add data
            self.__class__.session.commit()  # Commit
            self.__class__.is_setup = True

    def tearDown(self):
        if self.__class__.is_setup:
            self.__class__.session.close()

    # Just for Python >=2.7 or >=3.2
    @classmethod
    def setUpClass(cls):
        pass

    # Just for Python >=2.7 or >=3.2
    @classmethod
    def tearDownClass(cls):
        pass

if __name__ == '__main__':
    unittest.main()
The most likely reason for this behavior is the fact that the data is not properly cleaned up between the tests. This explains why, when you run only one test, it works.
setUp is called before every test, and tearDown after.
Depending on what you would like to achieve, you have two options:
Create the data only once for all tests. If you had Python 2.7+ or Python 3.2+, you could use the tearDownClass method. In your case you can handle it with a boolean class variable to prevent the code you have in setUp from running more than once.
Re-create the data before every test. In this case you need to make sure that tearDown deletes all the data. This is what you are not doing right now, and I suspect that when the second test is run, the call to one() fails not because it finds no object, but because it finds more than one object matching the criteria. A cleanup sketch follows below.
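A minimal sketch of such a tearDown (my own sketch, assuming the models from the question's first snippet; per-instance delete works regardless of the inheritance setup):

def tearDown(self):
    self.session.rollback()  # discard anything a failed test left pending
    # delete every fixture row so the next setUp starts from an empty table
    for item in self.session.query(PrintItem).all() + self.session.query(DigitalItem).all():
        self.session.delete(item)
    self.session.commit()
    self.session.close()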
Check the output of this code to understand the call sequence:
import unittest

class TestSchema(unittest.TestCase):
    def testOne(self):
        print('==testOne')

    def testTwo(self):
        print('==testTwo')

    def setUp(self):
        print('>>setUp')

    def tearDown(self):
        print('<<tearDown')

    @classmethod
    def setUpClass(cls):
        print('>>setUpClass')

    @classmethod
    def tearDownClass(cls):
        print('<<tearDownClass')

if __name__ == '__main__':
    unittest.main()
Output:
>>setUpClass
>>setUp
==testOne
<<tearDown
>>setUp
==testTwo
<<tearDown
<<tearDownClass
I have this as my tearDown method and it works fine for my tests:
def tearDown(self):
    """Cleans up after each test case."""
    sqlalchemy.orm.clear_mappers()