Save Nested Objects to File in Python3 - json

How can I save this structure of Python objects into a file (preferably JSON)? And how can I load this structure from the file again?
class Nested(object):
    def __init__(self, n):
        self.name = "Nested Object: " + str(n)
        self.state = 3.14159265359

class Nest(object):
    def __init__(self):
        self.x = 1
        self.y = 2
        self.objects = []

tree = []
tree.append(Nest())
tree.append(Nest())
tree.append(Nest())
tree[0].objects.append(Nested(1))
tree[0].objects.append(Nested(2))
tree[1].objects.append(Nested(1))
tree[2].objects.append(Nested(7))
tree[2].objects.append(Nested(8))
tree[2].objects.append(Nested(9))

Thanks to the reference to "pickle", I found a very simple, well-working solution for saving my array of objects:
pickle
import pickle

with open("save.p", "wb") as f:
    pickle.dump(tree, f)

with open("save.p", "rb") as f:
    loaded_objects = pickle.load(f)
jsonpickle
import jsonpickle

frozen = jsonpickle.encode(tree)
with open("save.json", "w") as text_file:
    print(frozen, file=text_file)

with open("save.json", "r") as text_file:
    loaded_objects = jsonpickle.decode(text_file.read())
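For reference, jsonpickle records each object's class in a py/object key, so save.json looks roughly like this (abbreviated; the exact layout can vary between jsonpickle versions):

[{"py/object": "__main__.Nest", "x": 1, "y": 2,
  "objects": [{"py/object": "__main__.Nested",
               "name": "Nested Object: 1", "state": 3.14159265359}, ...]}, ...]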

If you don't want pickle and don't want an external library either, you can always do it the hard way:
import json

class NestEncoder(json.JSONEncoder):
    def default(self, obj):
        entry = dict(obj.__dict__)
        entry['__class__'] = obj.__class__.__name__
        return entry

class NestDecoder(json.JSONDecoder):
    def __init__(self):
        json.JSONDecoder.__init__(self, object_hook=self.dict_to_object)

    def dict_to_object(self, dictionary):
        if dictionary.get("__class__") == "Nested":
            obj = Nested.__new__(Nested)
        elif dictionary.get("__class__") == "Nest":
            obj = Nest.__new__(Nest)
        else:
            return dictionary
        for key, value in dictionary.items():
            if key != '__class__':
                setattr(obj, key, value)
        return obj

with open('nest.json', 'w') as file:
    json.dump(tree, file, cls=NestEncoder)

with open('nest.json', 'r') as file:
    tree2 = json.load(file, cls=NestDecoder)

print("Smoke test:")
print(tree[0].objects[0].name)
print(tree2[0].objects[0].name)
The attributes don't have to be assigned dynamically with setattr(); you can also set them manually.
There are probably plenty of pitfalls with doing it this way, so be careful.
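For illustration, a manual version of dict_to_object might look like this (a sketch; it assumes the saved dictionaries contain exactly the attributes of the classes above):

def dict_to_object(self, dictionary):
    # Manual variant: rebuild each class by hand instead of via setattr().
    if dictionary.get("__class__") == "Nested":
        obj = Nested.__new__(Nested)
        obj.name = dictionary["name"]
        obj.state = dictionary["state"]
        return obj
    elif dictionary.get("__class__") == "Nest":
        obj = Nest.__new__(Nest)
        obj.x = dictionary["x"]
        obj.y = dictionary["y"]
        obj.objects = dictionary["objects"]
        return obj
    return dictionary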

Related

Bizarre Environment-dependent Bad Request 400 error

I'm writing a program that converts a repository into a Docker image with an API based on some specification files. When I run the app in my MacBook's base environment, the generated API works perfectly with both gunicorn and uwsgi. However, within the miniconda-based Docker container, it fails with Bad Request 400: The browser (or proxy) sent a request that this server could not understand. My goal is to eliminate this error. Obviously, this has to do with the versions of some dependency or set of dependencies. Interestingly, the last endpoint in the API, which has a request parser in its own namespace and takes no arguments, works perfectly, unlike the two other endpoints in the default namespace that do take arguments.
The API is built on flask_restx and uses reqparse.
The API code is here:
from flask_restx import Api, Resource, Namespace, reqparse, inputs
import flask

import process
from load_data import store_data

app = flask.Flask("restful_api")
api = Api(app, title="My API", description="This is an extremely useful API for performing tasks you would do with an API.", version="3.14")

data = {}
data.update(store_data())

class DefaultClass():

    def __init__(self):
        self.data = data

    def _replace_get(self, **args):
        default_args = {}
        args = {**default_args, **args}
        return process.replace(**args)

    def _find_get(self, **args):
        default_args = {"data": self.data["data"]}
        args = {**default_args, **args}
        return process.find_in_data_string(**args)

def set_up_worker():
    global defaultClass
    defaultClass = DefaultClass()

set_up_worker()

_replaceGetParser = reqparse.RequestParser()
_replaceGetParser.add_argument("txt",
                               type=str,
                               required=True,
                               help="Text to search ")
_replaceGetParser.add_argument("old",
                               type=str,
                               required=True,
                               help="Substring to replace ")
_replaceGetParser.add_argument("new",
                               type=str,
                               required=True,
                               help="Replacement for old ")
_replaceGetParser.add_argument("irrelevant_parameter",
                               type=int,
                               required=False,
                               default=5,
                               help="")
_replaceGetParser.add_argument("smart_casing",
                               type=inputs.boolean,
                               required=False,
                               default=True,
                               help="True if we should infer replacement capitalization from original casing. ")
_replaceGetParser.add_argument("case_sensitive",
                               type=inputs.boolean,
                               required=False,
                               default=True,
                               help="True if we should only replace case-sensitive matches ")

_findGetParser = reqparse.RequestParser()
_findGetParser.add_argument("window",
                            type=int,
                            required=False,
                            default=5,
                            help="Number of characters before and after first match to return ")
_findGetParser.add_argument("txt",
                            type=str,
                            required=False,
                            default="quick",
                            help="Your search term ")

@api.route('/replace', endpoint='replace', methods=['GET'])
@api.doc('defaultClass')
class ReplaceFrontend(Resource):

    @api.expect(_replaceGetParser)
    def get(self):
        args = _replaceGetParser.parse_args()
        return defaultClass._replace_get(**args)

@api.route('/find', endpoint='find', methods=['GET'])
@api.doc('defaultClass')
class FindFrontend(Resource):

    @api.expect(_findGetParser)
    def get(self):
        args = _findGetParser.parse_args()
        return defaultClass._find_get(**args)

retrievalNamespace = Namespace("retrieval", description="Data retrieval operations")

class RetrievalNamespaceClass():

    def __init__(self):
        self.data = data

    def _retrieval_retrieve_data_get(self, **args):
        default_args = {"data": self.data["data"]}
        args = {**default_args, **args}
        return process.return_data(**args)

def set_up_retrieval_worker():
    global retrievalNamespaceClass
    retrievalNamespaceClass = RetrievalNamespaceClass()

set_up_retrieval_worker()

_retrieval_retrieve_dataGetParser = reqparse.RequestParser()

@retrievalNamespace.route('/retrieval/retrieve_data', endpoint='retrieval/retrieve_data', methods=['GET'])
@retrievalNamespace.doc('retrievalNamespaceClass')
class Retrieval_retrieve_dataFrontend(Resource):

    @retrievalNamespace.expect(_retrieval_retrieve_dataGetParser)
    def get(self):
        args = _retrieval_retrieve_dataGetParser.parse_args()
        return retrievalNamespaceClass._retrieval_retrieve_data_get(**args)

api.add_namespace(retrievalNamespace)
I have had this problem with both pip-installed gunicorn and conda-installed uwsgi. I'm putting the file imported by the API at the end, since I think the exact function definitions are likely irrelevant.
import numpy as np
import pandas as pd
import re
from subprocess import Popen, PIPE
from flask_restx import abort

def replace(txt: str = '',  # apireq
            old: str = '',  # apireq
            new: str = '',  # apireq
            case_sensitive: bool = True,
            smart_casing: bool = True,
            irrelevant_parameter: int = 5):
    """
    Search and replace within a string, as long as the string and replacement
    contain no four letter words.

    arguments:
        txt: Text to search
        old: Substring to replace
        new: Replacement for old
        case_sensitive: True if we should only replace case-sensitive matches
        smart_casing: True if we should infer replacement capitalization
            from original casing.

    return
        return value
    """
    four_letter_words = [re.match('[a-zA-Z]{4}$', word).string
                         for word in ('%s %s' % (txt, new)).split()
                         if re.match('[a-zA-Z]{4}$', word)]
    if four_letter_words:
        error_message = ('Server refuses to process four letter word(s) %s'
                         % ', '.join(four_letter_words[:5])
                         + (', etc' if len(four_letter_words) > 5 else ''))
        abort(403, custom=error_message)
    return_value = {}
    if not case_sensitive:
        return_value['output'] = txt.replace(old, new)
    else:
        lowered = txt.replace(old, old.lower())
        return_value['output'] = lowered.replace(old.lower(), new)
    return return_value

def find_in_data_string(txt: str = "quick",  # req
                        window: int = 5,
                        data=None):  # noapi
    """
    Check if there is a match for your search string in our extensive database,
    and return the position of the first match with the surrounding text.

    arguments:
        txt: Your search term
        data: The server's text data
        window: Number of characters before and after first match to return
    """
    return_value = {}
    if txt in data:
        idx = data.find(txt)
        min_idx = max(idx - window, 0)
        max_idx = min(idx + len(txt) + window, len(data) - 1)
        return_value['string_found'] = True
        return_value['position'] = idx
        return_value['surrounding_string'] = data[min_idx:max_idx]
        return_value['surrounding_string_indices'] = [min_idx, max_idx]
    else:
        return_value = {'string_found': False}  # fixed: a list is not a valid dict key
    return return_value

def return_data(data=None):  # noapi
    """
    Return all the data in our text database.
    """
    # Note: shell=True combined with a list argument would run `which` with no
    # arguments at all; without it, `which aws` runs as presumably intended.
    with Popen(['which', 'aws'], stdout=PIPE) as p:
        output = p.stdout.read()
    try:
        assert not output.strip()
    except AssertionError:
        abort(503, custom='The server is incorrectly configured.')
    return_value = {'data': data}
    return return_value

ROS service failed to save files

I want to have a service 'save_readings' that automatically saves data from a rostopic into a file. But each time the service gets called, it doesn't save any file.
I've tried running the file-saving code in plain Python without a rosservice, and it works fine.
I don't understand why this is happening.
#!/usr/bin/env python
# license removed for brevity
import rospy
import numpy
from std_msgs.msg import String, Int32MultiArray, Float32MultiArray, Bool
from std_srvs.srv import Empty, EmptyResponse
import geometry_msgs.msg
from geometry_msgs.msg import WrenchStamped
import json
# import settings

pos_record = []
wrench_record = []

def ftmsg2listandflip(ftmsg):
    return [ftmsg.wrench.force.x, ftmsg.wrench.force.y, ftmsg.wrench.force.z,
            ftmsg.wrench.torque.x, ftmsg.wrench.torque.y, ftmsg.wrench.torque.z]

def callback_pos(data):
    global pos_record
    pos_record.append(data.data)

def callback_wrench(data):
    global wrench_record
    ft = ftmsg2listandflip(data)
    wrench_record.append([data.header.stamp.to_sec()] + ft)

def exp_listener():
    rospy.Subscriber("stage_pos", Float32MultiArray, callback_pos)
    rospy.Subscriber("netft_data", WrenchStamped, callback_wrench)
    rospy.spin()  # blocks until the node is shut down

def start_read(req):
    global pos_record
    global wrench_record
    pos_record = []
    wrench_record = []
    return EmptyResponse()

def save_readings(req):
    global pos_record
    global wrench_record
    filename = rospy.get_param('save_file_name')
    output_data = {'pos_list': pos_record, 'wrench_list': wrench_record}
    rospy.loginfo("output_data %s", output_data)
    with open(filename, 'w') as outfile:  # write data to 'data.json'
        print('dumping json file')
        json.dump(output_data, outfile)  # TODO: find out why saving the file fails
    print("file saved")
    rospy.sleep(2)
    return EmptyResponse()

if __name__ == '__main__':
    try:
        rospy.init_node('lisener_node', log_level=rospy.INFO)
        s_1 = rospy.Service('start_read', Empty, start_read)
        s_2 = rospy.Service('save_readings', Empty, save_readings)  # was overwriting s_1
        print('mylistener ready!')
        exp_listener()  # spin last, so the print above is actually reached
    except rospy.ROSInterruptException:
        pass
Got it. I need to specify a path for the file to be saved.
save_path = '/home/user/catkin_ws/src/motionstage/'
filename = save_path + filename
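A slightly more defensive version of the same fix (a sketch; the directory is taken from the answer above) builds the path with os.path.join so a missing trailing slash can't produce a wrong filename:

import os

save_path = '/home/user/catkin_ws/src/motionstage'  # assumed location
filename = os.path.join(save_path, rospy.get_param('save_file_name'))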

Does IOError: [Errno 2] No such file or directory: mean the file hasn't been written?

I'm using Tweepy for the first time. Currently getting this error
---------------------------------------------------------------------------
IOError Traceback (most recent call last)
<ipython-input-11-cdd7ebe0c00f> in <module>()
----> 1 data_json = io.open('raw_tweets.json', mode='r', encoding='utf-8').read() #reads in the JSON file
2 data_python = json.loads(data_json)
3
4 csv_out = io.open('tweets_out_utf8.csv', mode='w', encoding='utf-8') #opens csv file
IOError: [Errno 2] No such file or directory: 'raw_tweets.json'
I've got a feeling that the code I've got isn't working. For example print(status) doesn't print anything. Also I see no saved CSV or JSON file in the directory.
I'm a newbie so any help/documentation you can offer would be great!
import time
from tweepy import Stream
from tweepy import OAuthHandler
from tweepy.streaming import StreamListener
import os
import json
import csv
import io
from pymongo import MongoClient

ckey = 'blah'
consumer_secret = 'blah'
access_token_key = 'blah'
access_token_secret = 'blah'

#start_time = time.time() #grabs the system time
keyword_list = ['keyword'] #track list

#Listener Class Override
class listener(StreamListener):

    def __init__(self, start_time, time_limit=60):
        self.time = start_time
        self.limit = time_limit
        self.tweet_data = []

    def on_data(self, data):
        saveFile = io.open('raw_tweets.json', 'a', encoding='utf-8')
        while (time.time() - self.time) < self.limit:
            try:
                self.tweet_data.append(data)
                return True
            except BaseException, e:
                print 'failed ondata,', str(e)
                time.sleep(5)
                pass
        saveFile = io.open('raw_tweets.json', 'w', encoding='utf-8')
        saveFile.write(u'[\n')
        saveFile.write(','.join(self.tweet_data))
        saveFile.write(u'\n]')
        saveFile.close()
        exit()

    def on_error(self, status):
        print status

class listener(StreamListener):

    def __init__(self, start_time, time_limit=10):
        self.time = start_time
        self.limit = time_limit

    def on_data(self, data):
        while (time.time() - self.time) < self.limit:
            print(data)
            try:
                client = MongoClient('blah', 27017)
                db = client['blah']
                collection = db['blah']
                tweet = json.loads(data)
                collection.insert(tweet)
                return True
            except BaseException as e:
                print('failed ondata,')
                print(str(e))
                time.sleep(5)
                pass
        exit()

    def on_error(self, status):
        print(status)

data_json = io.open('raw_tweets.json', mode='r', encoding='utf-8').read() #reads in the JSON file
data_python = json.loads(data_json)

csv_out = io.open('tweets_out_utf8.csv', mode='w', encoding='utf-8') #opens csv file
UPDATED: Creates file but file is empty
import tweepy
import datetime

auth = tweepy.OAuthHandler('xxx', 'xxx')
auth.set_access_token('xxx', 'xxx')

class listener(tweepy.StreamListener):

    def __init__(self, timeout, file_name, *args, **kwargs):
        super(listener, self).__init__(*args, **kwargs)
        self.start_time = None
        self.timeout = timeout
        self.file_name = file_name
        self.tweet_data = []

    def on_data(self, data):
        if self.start_time is None:
            self.start_time = datetime.datetime.now()
        while (datetime.datetime.now() - self.start_time).seconds < self.timeout:
            with open(self.file_name, 'a') as data_file:
                data_file.write('\n')
                data_file.write(data)

    def on_error(self, status):
        print status

l = listener(60, 'stack_raw_tweets.json')
mstream = tweepy.Stream(auth=auth, listener=l)
mstream.filter(track=['python'], async=True)
You are not creating a Stream for the listener; the next-to-last line of the code below does that, and the last line then starts the Stream. I must warn you: storing this in MongoDB is the right thing to do, as the file I am storing grows easily to several GB. Also, the file is not exactly JSON: each line in the file is a JSON object. You must tweak it to your needs.
import tweepy
import datetime

auth = tweepy.OAuthHandler(consumer_key, consumer_secret)
auth.set_access_token(access_token, access_secret)

class listener(tweepy.StreamListener):

    def __init__(self, timeout, file_name, *args, **kwargs):
        super(listener, self).__init__(*args, **kwargs)
        self.start_time = None
        self.timeout = timeout
        self.file_name = file_name
        self.tweet_data = []

    def on_data(self, data):
        if self.start_time is None:
            self.start_time = datetime.datetime.now()
        while (datetime.datetime.now() - self.start_time).seconds < self.timeout:
            with open(self.file_name, 'a') as data_file:
                data_file.write('\n')
                data_file.write(data)

    def on_error(self, status):
        print status

l = listener(60, 'raw_tweets.json')
mstream = tweepy.Stream(auth=auth, listener=l)
mstream.filter(track=['python'], async=True)
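Since the file ends up with one JSON object per line rather than a single JSON document, reading it back might look roughly like this (a sketch, assuming the raw_tweets.json produced above):

import io
import json

tweets = []
with io.open('raw_tweets.json', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if line:  # the writer inserts blank lines between objects
            tweets.append(json.loads(line))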

pyqt4 QTableView in QMainWindow with csv input and headers

I am working with a QMainWindow and adding a QTableView widget. The table is to be filled with data from a csv file. The first row of the csv file contains the headers, but I cannot find out how to write that row into the headers; even inputting a test header list does not work.
Also I want to reverse sort on the "time" column.
Here is code restricted to mostly the table:
import sys
import csv
from PyQt4 import QtGui
from PyQt4.QtCore import *
from array import *

class UserWindow(QtGui.QMainWindow):

    def __init__(self, parent=None):
        super(UserWindow, self).__init__()
        self.specModel = QtGui.QStandardItemModel(self)
        self.specList = self.createSpecTable()
        self.initUI()

    def specData(self):
        with open('testFile.csv', 'rb') as csvInput:
            for row in csv.reader(csvInput):
                if row > 0:
                    items = [QtGui.QStandardItem(field) for field in row]
                    self.specModel.appendRow(items)

    def createSpecTable(self):
        self.specTable = QtGui.QTableView()
        # This is a test header - different from what is needed
        specHdr = ['Test', 'Date', 'Time', 'Type']
        self.specData()
        specM = specTableModel(self.specModel, specHdr, self)
        self.specTable.setModel(specM)
        self.specTable.setShowGrid(False)
        vHead = self.specTable.verticalHeader()
        vHead.setVisible(False)
        hHead = self.specTable.horizontalHeader()
        hHead.setStretchLastSection(True)
        self.specTable.sortByColumn(3, Qt.DescendingOrder)
        return self.specTable

    def initUI(self):
        self.ctr_frame = QtGui.QWidget()
        self.scnBtn = QtGui.QPushButton("Sample")
        self.refBtn = QtGui.QPushButton("Reference")
        self.stpBtn = QtGui.QPushButton("Blah")
        # List Window
        self.specList.setModel(self.specModel)
        # Layout of Widgets
        pGrid = QtGui.QGridLayout()
        pGrid.setSpacing(5)
        pGrid.addWidget(self.scnBtn, 3, 0, 1, 2)
        pGrid.addWidget(self.refBtn, 3, 2, 1, 2)
        pGrid.addWidget(self.stpBtn, 3, 4, 1, 2)
        pGrid.addWidget(self.specList, 10, 0, 20, 6)
        self.ctr_frame.setLayout(pGrid)
        self.setCentralWidget(self.ctr_frame)
        self.statusBar()
        self.setGeometry(300, 300, 400, 300)
        self.setWindowTitle('Test')

class specTableModel(QAbstractTableModel):

    def __init__(self, datain, headerdata, parent=None, *args):
        QAbstractTableModel.__init__(self, parent, *args)
        self.arraydata = datain
        self.headerdata = headerdata

    def rowCount(self, parent):
        return len(self.arraydata)

    def columnCount(self, parent):
        return len(self.arraydata[0])

    def data(self, index, role):
        if not index.isValid():
            return QVariant()
        elif role != Qt.DisplayRole:
            return QVariant()
        return QVariant(self.arraydata[index.row()][index.column()])

    def headerData(self, col, orientation, role):
        if orientation == Qt.Horizontal and role == Qt.DisplayRole:
            return self.headerdata[col]
        return None

def main():
    app = QtGui.QApplication(sys.argv)
    app.setStyle(QtGui.QStyleFactory.create("plastique"))
    ex = UserWindow()
    ex.show()
    sys.exit(app.exec_())

if __name__ == '__main__':
    main()
and here is a really short csv file:
Run,Date,Time,Comment
data1,03/03/2014,00:04,Reference
data2,03/03/2014,02:00,Reference
data5,03/03/2014,02:08,Sample
data6,03/03/2014,13:57,Sample
Also the rowCount & columnCount definitions do not work.
Worked out answers to what I posted: I wrote a 'getHeader' function that simply reads the first line of the csv file and returns it as a list (a sketch follows the snippet below). I added the following to the createSpecTable function:
specHdr = self.getHeader()
self.specModel.setHorizontalHeaderLabels(specHdr)
self.specModel.sort(2, Qt.DescendingOrder)
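The getHeader function itself isn't shown in the post; a minimal sketch of what it might look like (file name taken from the question):

def getHeader(self):
    # Hypothetical helper: read just the first row of the csv file.
    with open('testFile.csv', 'rb') as csvInput:
        return next(csv.reader(csvInput))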
The last statement solved the reverse sort problem. The header line from the csv file was removed from the table by adding a last line to the specData function:
self.specModel.removeRow(0)
Finally the rowCount and columnCount were corrected with:
def rowCount(self, parent):
    return self.arraydata.rowCount()

def columnCount(self, parent):
    return self.arraydata.columnCount()

SQLAlchemy session woes in unittest

I just started using SQLAlchemy a few days ago and right now I'm stuck with a problem that I hope someone can shed some light on before I lose all my hair.
When I run a unittest, see the snippet below, only the first test in the sequence passes. The test testPrintItem works just fine, but testDigitalItem fails with a NoResultFound exception - No row was found for one(). But if I remove testPrintItem from the test class, then testDigitalItem works.
I assume the problem has something to do with the session, but I can't really get a grip on it.
In case anyone wonders, the setup is as follows:
Python 3.1 (Ubuntu 10.04 package)
SQLAlchemy 0.7.2 (easy_install:ed)
PostgreSQL 8.4.8 (Ubuntu 10.04 package)
PsycoPG2 2.4.2 (easy_installed:ed)
Example test:
class TestSchema(unittest.TestCase):

    test_items = [
        # Some parent class products
        PrintItem(key='p1', title='Possession', dimension='30x24'),
        PrintItem(key='p2', title='Andrzej Żuławski - a director', dimension='22x14'),
        DigitalItem(key='d1', title='Every Man His Own University', url='http://www.gutenberg.org/files/36955/36955-h/36955-h.htm'),
        DigitalItem(key='d2', title='City Ballads', url='http://www.gutenberg.org/files/36954/36954-h/36954-h.htm'),
    ]

    def testPrintItem(self):
        item = self.session.query(PrintItem).filter(PrintItem.key == 'p1').one()
        assert item.title == 'Possession', 'Title mismatch'

    def testDigitalItem(self):
        item2 = self.session.query(DigitalItem).filter(DigitalItem.key == 'd2').one()
        assert item2.title == 'City Ballads', 'Title mismatch'

    def setUp(self):
        Base.metadata.create_all()
        self.session = DBSession()
        self.session.add_all(self.test_items)
        self.session.commit()

    def tearDown(self):
        self.session.close()
        Base.metadata.drop_all()

if __name__ == '__main__':
    unittest.main()
UPDATE
Here is the working code snippet.
# -*- coding: utf-8 -*-
import time
import unittest

from sqlalchemy import *
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import *

Base = declarative_base()
engine = create_engine('sqlite:///testdb', echo=False)
DBSession = sessionmaker(bind=engine)

class ItemMixin(object):
    """
    Commons attributes for items, ie books, DVD:s...
    """
    __tablename__ = 'testitems'
    __table_args__ = {'extend_existing': True}

    id = Column(Integer, autoincrement=True, primary_key=True)
    key = Column(Unicode(16), unique=True, nullable=False)
    title = Column(UnicodeText, default=None)
    item_type = Column(Unicode(20), default=None)

    __mapper_args__ = {'polymorphic_on': item_type}

    def __init__(self, key, title=None):
        self.key = key
        self.title = title

class FooItem(Base, ItemMixin):
    foo = Column(UnicodeText, default=None)

    __mapper_args__ = {'polymorphic_identity': 'foo'}

    def __init__(self, foo=None, **kwargs):
        ItemMixin.__init__(self, **kwargs)
        self.foo = foo

class BarItem(Base, ItemMixin):
    bar = Column(UnicodeText, default=None)

    __mapper_args__ = {'polymorphic_identity': 'bar'}

    def __init__(self, bar=None, **kwargs):
        ItemMixin.__init__(self, **kwargs)
        self.bar = bar

# Tests
class TestSchema(unittest.TestCase):

    # Class variables
    is_setup = False
    session = None
    metadata = None

    test_items = [
        FooItem(key='f1', title='Possession', foo='Hello'),
        FooItem(key='f2', title='Andrzej Żuławsk', foo='World'),
        BarItem(key='b1', title='Wikipedia', bar='World'),
        BarItem(key='b2', title='City Ballads', bar='Hello'),
    ]

    def testFooItem(self):
        print('Test Foo Item')
        item = self.__class__.session.query(FooItem).filter(FooItem.key == 'f1').first()
        assert item.title == 'Possession', 'Title mismatch'

    def testBarItem(self):
        print('Test Bar Item')
        item = self.__class__.session.query(BarItem).filter(BarItem.key == 'b2').first()
        assert item.title == 'City Ballads', 'Title mismatch'

    def setUp(self):
        if not self.__class__.is_setup:
            self.__class__.session = DBSession()
            self.metadata = Base.metadata
            self.metadata.bind = engine
            self.metadata.drop_all()    # Drop tables
            self.metadata.create_all()  # Create tables
            self.__class__.session.add_all(self.test_items)  # Add data
            self.__class__.session.commit()                  # Commit
            self.__class__.is_setup = True

    def tearDown(self):
        if self.__class__.is_setup:
            self.__class__.session.close()

    # Just for Python >=2.7 or >=3.2
    @classmethod
    def setUpClass(cls):
        pass

    # Just for Python >=2.7 or >=3.2
    @classmethod
    def tearDownClass(cls):
        pass

if __name__ == '__main__':
    unittest.main()
The most likely reason for this behavior is that the data is not properly cleaned up between the tests, which explains why the test works when you run it alone.
setUp is called before every test, and tearDown after.
Depending on what you would like to achieve, you have two options:
1. Create the data only once for all tests. With Python 2.7+ or 3.2+ you could use the setUpClass and tearDownClass methods; in your case you can handle it with a boolean class variable that prevents the code in setUp from running more than once.
2. Re-create the data before every test. In this case you need to make sure that tearDown deletes all the data (see the sketch below). This is what you are not doing right now, and I suspect that when the second test is run, the call to one() fails not because it does not find an object, but because it finds more than one object matching the criteria.
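A minimal sketch of option 2, reusing the Base, engine, and DBSession names from the snippets above (make_test_items is a hypothetical factory that builds fresh instances for every test, since reusing the same class-level instances would defeat the cleanup):

def setUp(self):
    Base.metadata.create_all(engine)         # fresh tables for this test
    self.session = DBSession()
    self.session.add_all(make_test_items())  # hypothetical factory returning new objects
    self.session.commit()

def tearDown(self):
    self.session.close()
    Base.metadata.drop_all(engine)           # wipe all data so the next test starts clean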
Check the output of this code to understand the call sequence:
import unittest

class TestSchema(unittest.TestCase):

    def testOne(self):
        print '==testOne'

    def testTwo(self):
        print '==testTwo'

    def setUp(self):
        print '>>setUp'

    def tearDown(self):
        print '<<tearDown'

    @classmethod
    def setUpClass(cls):
        print '>>setUpClass'

    @classmethod
    def tearDownClass(cls):
        print '<<tearDownClass'

if __name__ == '__main__':
    unittest.main()
Output:
>>setUp
==testOne
<<tearDown
>>setUp
==testTwo
<<tearDown
I have this as my tearDown method and it does work fine for my tests:
def tearDown(self):
    """Cleans up after each test case."""
    sqlalchemy.orm.clear_mappers()