Scrapy MySQL pipeline: spider closed before pipeline finished

I use Scrapy to crawl a page that contains a list of items, and I save each item to a MySQL database.
The problem is that the spider closes before all the items are stored in MySQL; each time I run the spider, the resulting row count is different.
Could you please help me figure out how to solve this?
Below is my sample code:
Spider:
import re

import requests
import scrapy
from scrapy.selector import Selector

class FutianSpider(scrapy.Spider):
    name = 'futian_soufang'
    allowed_domains = ["fang.com"]  # was allowed_domain, which Scrapy silently ignores
    start_urls = []

    def __init__(self, category=None, *args, **kwargs):
        super(FutianSpider, self).__init__(*args, **kwargs)
        self.count = 0

    def closed(self, reason):
        print "*" * 20 + str(self.count)

    def start_requests(self):
        url = "http://fangjia.fang.com/pghouse-c0sz/a085-h321-i3{}/"
        response = requests.get(url.format(1))
        response.encoding = 'gb2312'
        strpages = Selector(text=response.text).xpath('//p[contains(@class, "pages")]/span[last()]/a/text()').extract()
        pages = int(strpages[0])
        for num in range(1, pages + 1):
            yield scrapy.Request(url.format(num), callback=self.parse_page)

    def parse_page(self, response):
        houses = response.xpath("//div[@class='list']//div[@class='house']")
        for house in houses:
            self.count += 1
            housespan_hyperlink = house.xpath(".//span[@class='housetitle']/a")
            house_title = housespan_hyperlink.xpath("text()").extract()[0].strip()
            house_link_rel = housespan_hyperlink.xpath("@href").extract()[0].strip()
            house_link = response.urljoin(house_link_rel)
            address = house.xpath(".//span[@class='pl5']/text()").extract()[0].strip()
            esf_keyword = u'二手房'
            esf_span = house.xpath(".//span[contains(text(),'%s')]" % esf_keyword)
            esf_number = esf_span.xpath("./a/text()").extract()[0].strip()
            esf_number = int(re.findall(r"\d+", esf_number)[0])
            esf_link = esf_span.xpath("./a/@href").extract()[0].strip()
            zf_hyperlink = house.xpath(".//span[@class='p110']/a")
            zf_number = zf_hyperlink.xpath("text()").extract()[0].strip()
            zf_number = int(re.findall(r"\d+", zf_number)[0])
            zf_link = zf_hyperlink.xpath("@href").extract()[0].strip()
            price = 0
            try:
                price = int(house.xpath(".//span[@class='price']/text()").extract()[0].strip())
            except:
                pass
            change = 0.0
            try:
                increase_span = house.xpath(".//span[contains(@class, 'nor')]")
                changetext = increase_span.xpath("text()").extract()[0].strip()
                change = float(changetext[:changetext.index('%')])
                if len(increase_span.css(".green-down")) > 0:
                    change *= -1
            except:
                pass
            print house_title, house_link, address, esf_number, esf_link, zf_number, zf_link, price, change
            item = XiaoquItem(
                title=house_title,
                url=house_link,
                address=address,
                esf_number=esf_number,
                esf_link=esf_link,
                zf_number=zf_number,
                zf_link=zf_link,
                price=price,
                change=change
            )
            yield item
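As an aside, the requests.get call in start_requests blocks outside Scrapy's scheduler. A non-blocking sketch of the same logic (parse_pagination and page_url_template are hypothetical names for illustration; page_url_template would hold the format string used above):

    def start_requests(self):
        # let Scrapy fetch page 1 instead of using the blocking requests.get
        yield scrapy.Request(self.page_url_template.format(1), callback=self.parse_pagination)

    def parse_pagination(self, response):
        strpages = response.xpath('//p[contains(@class, "pages")]/span[last()]/a/text()').extract()
        pages = int(strpages[0])
        for item in self.parse_page(response):  # page 1 is already in hand
            yield item
        for num in range(2, pages + 1):  # request the remaining pages
            yield scrapy.Request(self.page_url_template.format(num), callback=self.parse_page)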
Item:
from scrapy import Item, Field

class XiaoquItem(Item):
    # define the fields for your item here:
    title = Field()
    url = Field()
    address = Field()
    esf_number = Field()
    esf_link = Field()
    zf_number = Field()
    zf_link = Field()
    price = Field()
    change = Field()
Pipeline:
import datetime

from scrapy import log
from scrapy.utils.project import get_project_settings
from twisted.enterprise import adbapi

class MySQLPipeLine(object):
    def __init__(self):
        settings = get_project_settings()
        dbargs = settings.get('DB_CONNECT')
        db_server = settings.get('DB_SERVER')
        self.dbpool = adbapi.ConnectionPool(db_server, **dbargs)

    def close_spider(self, spider):
        self.dbpool.close()

    def process_item(self, item, spider):
        if isinstance(item, XiaoquItem):
            self._process_plot(item)
        elif isinstance(item, PlotMonthlyPriceItem):
            self._process_plot_price(item)
        return item

    def _process_plot(self, item):
        # run the db query in the thread pool
        query = self.dbpool.runInteraction(self._conditional_insert, item)
        query.addErrback(self._handle_error, item)

    def _conditional_insert(self, conn, item):
        # create the record only if it doesn't exist;
        # this whole block runs on its own thread
        conn.execute("select * from houseplot where title = %s", (item["title"],))
        result = conn.fetchone()
        if result:
            log.msg("Item already stored in db: %s" % item, level=log.DEBUG)
        else:
            conn.execute(
                "insert into houseplot(title, url, address, esf_number, esf_link, zf_number, zf_link, price, price_change, upsert_time) "
                "values (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)",
                (item["title"], item["url"], item["address"], int(item["esf_number"]), item["esf_link"],
                 item["zf_number"], item["zf_link"], item["price"], item["change"], datetime.datetime.now())
            )
            log.msg("Item stored in db: %s" % item, level=log.DEBUG)

    def _handle_error(self, failure, item):
        # addErrback(self._handle_error, item) passes the Failure first, then item
        log.err(failure)

    def _process_plot_price(self, item):
        query = self.dbpool.runInteraction(self._conditional_insert_price, item)
        query.addErrback(self._handle_error, item)

    def _conditional_insert_price(self, conn, item):
        conn.execute("select * from houseplot_monthly_price where title = %s and price_date = %s",
                     (item["title"], item["price_date"]))
        result = conn.fetchone()
        if result:
            log.msg("Price Item already stored in db: %s" % item, level=log.DEBUG)
        else:
            conn.execute(
                "insert into houseplot_monthly_price(title, price_date, price) values (%s, %s, %s)",
                (item["title"], item["price_date"], item["price"])
            )
            log.msg("Price Item stored in db: %s" % item, level=log.DEBUG)

Related

Remove a line of text from a JSON file

I am programming an economy bot with items, inventory, currency and much more. When I sell an item from my inventory, it still shows up in the embed. What I want: when I sell an item and its count in the inventory reaches 0, it should no longer be displayed in the inventory embed, i.e. it should be removed from the JSON file.
My code for the sell command:
@client.command()
async def sell(ctx, item, amount=1):
    await open_account(ctx.author)
    res = await sell_this(ctx.author, item, amount)
    em1 = discord.Embed(title=f"{ctx.author.name}",
                        description="The item could not be found in your inventory",
                        color=0xe67e22)
    em1.set_thumbnail(url=ctx.author.avatar_url)
    em2 = discord.Embed(title=f"{ctx.author.name}",
                        description=f"You do not have {amount} {item} in your inventory",
                        color=0xe67e22)
    em2.set_thumbnail(url=ctx.author.avatar_url)
    em3 = discord.Embed(title=f"{ctx.author.name}",
                        description=f"You do not have the item **{item}** in your inventory",
                        color=0xe67e22)
    em3.set_thumbnail(url=ctx.author.avatar_url)
    em4 = discord.Embed(title=f"{ctx.author.name}",
                        description=f"You sold {amount} {item}",
                        color=0xe67e22)
    em4.set_thumbnail(url=ctx.author.avatar_url)
    if not res[0]:
        if res[1] == 1:
            await ctx.send(embed=em1)
            return
        if res[1] == 2:
            await ctx.send(embed=em2)
            return
        if res[1] == 3:
            await ctx.send(embed=em3)
            return
    await ctx.send(embed=em4)

async def sell_this(user, item_name, amount, price=None):
    item_name = item_name.lower()
    name_ = None
    for item in mainshop:
        name = item["name"].lower()
        if name == item_name:
            name_ = name
            if price is None:
                price = 0.9 * item["price"]
            break
    if name_ is None:
        return [False, 1]
    cost = price * amount
    users = await get_bank_data()
    bal = await update_bank(user)
    try:
        index = 0
        t = None
        for thing in users[str(user.id)]["bag"]:
            n = thing["item"]
            if n == item_name:
                old_amt = thing["amount"]
                new_amt = old_amt - amount
                if new_amt < 0:
                    return [False, 2]
                users[str(user.id)]["bag"][index]["amount"] = new_amt
                t = 1
                break
            index += 1
        if t is None:
            return [False, 3]
    except:
        return [False, 3]
    with open("Bank.json", "w") as f:
        json.dump(users, f)
    await update_bank(user, cost, "wallet")
    return [True, "Worked"]
I hope someone can help me.
Use del to delete the item in question from the JSON data:
try:
    index = 0
    t = None
    for thing in users[str(user.id)]["bag"]:
        n = thing["item"]
        if n == item_name:
            old_amt = thing["amount"]
            new_amt = old_amt - amount
            if new_amt < 0:
                return [False, 2]
            elif new_amt == 0:  # check if the amount is 0
                del users[str(user.id)]["bag"][index]  # delete the item from the bag
            else:
                users[str(user.id)]["bag"][index]["amount"] = new_amt
            t = 1
            break
        index += 1
    if t is None:
        return [False, 3]
except:
    return [False, 3]
This completely removes the item and amount from the 'bag'.
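A standalone demo of the same pattern on hypothetical data (plain Python, no Discord needed) shows the entry disappearing once its amount reaches 0:

users = {"123": {"bag": [{"item": "fish", "amount": 1}]}}
bag = users["123"]["bag"]
for index, thing in enumerate(bag):
    if thing["item"] == "fish":
        thing["amount"] -= 1
        if thing["amount"] == 0:
            del bag[index]  # remove the exhausted entry entirely
        break
print(users)  # -> {'123': {'bag': []}}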

Why does my agent always take the same action in DQN - Reinforcement Learning

I have trained an RL agent using the DQN algorithm. After 20000 episodes my rewards converged. Now when I test this agent, it always takes the same action, irrespective of the state. I find this very weird. Can someone help me with this? Is there a reason anyone can think of why the agent behaves this way?
[Reward plot]
When I test the agent:
state = env.reset()
print('State: ', state)
state_encod = np.reshape(state, [1, state_size])
q_values = model.predict(state_encod)
action_key = np.argmax(q_values)
print(action_key)
print(index_to_action_mapping[action_key])
print(q_values[0][0])
print(q_values[0][action_key])
q_values_plotting = []
for i in range(0, action_size):
    q_values_plotting.append(q_values[0][i])
plt.plot(np.arange(0, action_size), q_values_plotting)
Every time it gives the same q_values plot, even though the initial state is different every time. Below is the q_values plot.
Testing code:
test_rewards = []
for episode in range(1000):
    terminal_state = False
    state = env.reset()
    episode_reward = 0
    while terminal_state == False:
        print('State: ', state)
        state_encod = np.reshape(state, [1, state_size])
        q_values = model.predict(state_encod)
        action_key = np.argmax(q_values)
        action = index_to_action_mapping[action_key]
        print('Action: ', action)
        next_state, reward, terminal_state = env.step(state, action)
        print('Next_state: ', next_state)
        print('Reward: ', reward)
        print('Terminal_state: ', terminal_state, '\n')
        print('----------------------------')
        episode_reward += reward
        state = deepcopy(next_state)
    print('Episode Reward: ' + str(episode_reward))
    test_rewards.append(episode_reward)
plt.plot(test_rewards)
Thanks.
Adding the environment:
import gym
import rom_vav_150mm_polyreg as rom
import numpy as np
import random

class VAVenv(gym.Env):
    def __init__(self):
        # Zone temperature set point and limits
        self.temp_sp = 24
        self.temp_sp_max = 24.5
        self.temp_sp_min = 23.7
        # no. of hours in an episode and time interval for each step
        self.MAXSTEPS = 11
        self.time_interval = 5. / 60.  # in hrs
        # constants
        self.zone_volume = 775

    def step(self, state, action):
        # state -> Time, Volume, Load, SAT, RAT
        # action -> CFM
        action_cfm = action[0]
        load = state[2]
        sat = state[3]
        current_temp = state[4]
        # input
        inputs_rat = np.array([load, action_cfm, self.zone_volume, current_temp, sat])
        # output, five minutes later
        output = [self.KStep + self.time_interval, self.zone_volume,
                  rom.load(self.KStep + self.time_interval),
                  sat, rom.rat(inputs_rat)]
        # reward calculation
        thermal_coefficient = -0.1
        zone_temperature = output[4]
        if zone_temperature < self.temp_sp_min:
            temp_penalty = self.temp_sp_min - zone_temperature
        elif zone_temperature > self.temp_sp_max:
            temp_penalty = zone_temperature - self.temp_sp_max
        else:
            temp_penalty = -10
        reward = thermal_coefficient * temp_penalty
        # create next step
        next_state = np.array(output)
        # increment simulation step count
        self.KStep += self.time_interval
        # done - end of one episode, when KStep reaches the maximum steps in an episode
        done = False
        if self.KStep > self.MAXSTEPS:
            done = True
        return next_state, reward, done

    def reset(self):
        self.KStep = 0
        # initialize all the values of a state
        initial_rat = random.uniform(23, 27)
        initial_sat = random.uniform(12, 14)
        # return a state
        return np.array([self.KStep, self.zone_volume,
                         rom.load(self.KStep), initial_sat, initial_rat])
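One quick sanity check, reusing the question's own model, env, state_size, and index_to_action_mapping: feed several freshly reset (and therefore randomized) states through the network and compare the greedy actions. If argmax never changes across clearly different inputs, either the Q-network has collapsed to a near-constant function, or the test states are not preprocessed/scaled the same way as during training:

import numpy as np

for _ in range(5):
    state = env.reset()  # reset() randomizes initial_rat and initial_sat
    state_encod = np.reshape(state, [1, state_size])
    q_values = model.predict(state_encod)
    print(state, np.argmax(q_values), index_to_action_mapping[np.argmax(q_values)])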

How to add extra fields in ValueQuerySet (Django)?

Basically, I want to convert the query_set to JSON. But I also want to add one more field, something like size = <some number>, which is not present in the query_set attributes (it is a computed attribute). Can you tell me how to do it?
query_set = PotholeCluster.objects.all().values('bearing', 'center_lat', 'center_lon', 'grid_id')
return JsonResponse(list(query_set), safe=False)
I tried the code below. It works, but I would like to know if there is any cleaner way to do this.
query_set = PotholeCluster.objects.all()
response_list = []
for pc in query_set:
    d = {}
    d['bearing'] = pc.get_bearing()
    d['center_lat'] = pc.center_lat
    d['center_lon'] = pc.center_lon  # was pc.center_lat, a copy-paste slip
    d['grid_id'] = pc.grid_id
    d['size'] = pc.pothole_set.all().count()
    response_list.append(d)
serialized = json.dumps(response_list)
return HttpResponse(serialized, content_type='application/json')
class PotholeCluster(models.Model):
    center_lat = models.FloatField(default=0)
    center_lon = models.FloatField(default=0)
    snapped_lat = models.FloatField(default=0)
    snapped_lon = models.FloatField(default=0)
    size = models.IntegerField(default=-1)
    # avg speed in kmph
    speed = models.FloatField(default=-1)
    # in meters
    accuracy = models.FloatField(default=-1)
    # avg bearing in degrees
    bearing = models.FloatField(default=-1)
    grid = models.ForeignKey(
        Grid,
        on_delete=models.SET_NULL,
        null=True,
        blank=True
    )

    def __str__(self):
        raw_data = serialize('python', [self])
        output = json.dumps(raw_data[0]['fields'])
        return "pk = {}|{}".format(self.id, output)

    def get_bearing(self):
        if self.bearing != -1:
            return self.bearing
        potholes = self.pothole_set.all()
        bearings = [pothole.location.bearing for pothole in potholes]
        bearings.sort()
        i = 0
        if bearings[-1] >= 350:
            while bearings[-1] - bearings[i] >= 340:
                if bearings[i] <= 10:
                    bearings[i] += 360
                i += 1
        self.bearing = sum(bearings) / len(bearings) % 360
        self.save()
        return self.bearing

    def get_size(self):
        if self.size != -1:
            return self.size
        self.size = len(self.pothole_set.all())
        self.save()
        return self.size
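For what it's worth, a cleaner variant of the size computation is a database annotation, assuming the reverse relation from Pothole uses the default name (so Count('pothole') applies). Note the alias cannot be size, because the model already has a size field:

from django.db.models import Count

query_set = (PotholeCluster.objects
             .annotate(num_potholes=Count('pothole'))
             .values('bearing', 'center_lat', 'center_lon', 'grid_id', 'num_potholes'))
return JsonResponse(list(query_set), safe=False)

This only covers the count; a Python-side computed attribute like get_bearing() still needs the loop from the question (or a database expression that reproduces its logic).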

Scrapy returns no output - just a [

I'm trying to run the spider found in this crawler, and for simplicity's sake I'm using this start_url because it is just a list of 320 movies (so the crawler won't run for 5 hours as stated on the GitHub page).
I crawl using scrapy crawl imdb -o output.json, but the output.json file contains nothing; it has just a [ in it.
import scrapy
from texteval.items import ImdbMovie, ImdbReview
import urlparse
import math
import re

class ImdbSpider(scrapy.Spider):
    name = "imdb"
    allowed_domains = ["imdb.com"]
    start_urls = [
        # "http://www.imdb.com/chart/top",
        # "http://www.imdb.com/chart/bottom"
        "http://www.imdb.com/search/title?countries=csxx&sort=moviemeter,asc"
    ]
    DOWNLOADER_MIDDLEWARES = {
        'scrapy.contrib.downloadermiddleware.robotstxt.ROBOTSTXT_OBEY': True,
    }
    base_url = "http://www.imdb.com"

    def parse(self, response):
        movies = response.xpath("//*[@id='main']/table/tr/td[3]/a/@href")
        for i in xrange(len(movies)):
            l = self.base_url + movies[i].extract()
            print l
            request = scrapy.Request(l, callback=self.parse_movie)
            yield request
        next = response.xpath("//*[@id='right']/span/a")[-1]
        next_url = self.base_url + next.xpath(".//@href")[0].extract()
        next_text = next.xpath(".//text()").extract()[0][:4]
        if next_text == "Next":
            request = scrapy.Request(next_url, callback=self.parse)
            yield request
        '''
        for sel in response.xpath("//table[@class='chart']/tbody/tr"):
            url = urlparse.urljoin(response.url, sel.xpath("td[2]/a/@href").extract()[0].strip())
            request = scrapy.Request(url, callback=self.parse_movie)
            yield request
        '''

    def parse_movie(self, response):
        movie = ImdbMovie()
        i1 = response.url.find('/tt') + 1
        i2 = response.url.find('?')
        i2 = i2 - 1 if i2 > -1 else i2
        movie['id'] = response.url[i1:i2]
        movie['url'] = "http://www.imdb.com/title/" + movie['id']
        r_tmp = response.xpath("//div[@class='titlePageSprite star-box-giga-star']/text()")
        if r_tmp is None or r_tmp == "" or len(r_tmp) < 1:
            return
        movie['rating'] = int(float(r_tmp.extract()[0].strip()) * 10)
        movie['title'] = response.xpath("//span[@itemprop='name']/text()").extract()[0]
        movie['reviews_url'] = movie['url'] + "/reviews"
        # Number of reviews associated with this movie
        n = response.xpath("//*[@id='titleUserReviewsTeaser']/div/div[3]/a[2]/text()")
        if n is None or n == "" or len(n) < 1:
            return
        n = n[0].extract().replace("See all ", "").replace(" user reviews", "")\
            .replace(" user review", "").replace(",", "").replace(".", "").replace("See ", "")
        if n == "one":
            n = 1
        else:
            n = int(n)
        movie['number_of_reviews'] = n
        r = int(math.ceil(n / 10))
        for x in xrange(1, r):
            start = x * 10 - 10
            url = movie['reviews_url'] + "?start=" + str(start)
            request = scrapy.Request(url, callback=self.parse_review)
            request.meta['movieObj'] = movie
            yield request

    def parse_review(self, response):
        ranks = response.xpath("//*[@id='tn15content']/div")[0::2]
        texts = response.xpath("//*[@id='tn15content']/p")
        del texts[-1]
        if len(ranks) != len(texts):
            return
        for i in xrange(0, len(ranks) - 1):
            review = ImdbReview()
            review['movieObj'] = response.meta['movieObj']
            review['text'] = texts[i].xpath("text()").extract()
            rating = ranks[i].xpath(".//img[2]/@src").re("-?\\d+")
            if rating is None or rating == "" or len(rating) < 1:
                return
            review['rating'] = int(rating[0])
            yield review
Can someone tell me where I am going wrong?
This site most likely loads the list of movies via JavaScript. First, check what movies = response.xpath("//*[@id='main']/table/tr/td[3]/a/@href") actually returns. If you need JS-rendered content, you can use a WebKit-based downloader middleware in Scrapy.
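For example, a quick way to verify this from scrapy shell, using the start URL and XPath from the question (an empty list means the selector matches nothing in the raw, non-JS HTML):

scrapy shell "http://www.imdb.com/search/title?countries=csxx&sort=moviemeter,asc"
>>> response.xpath("//*[@id='main']/table/tr/td[3]/a/@href").extract()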

Adjacency list + Abstract Base Class Inheritance used in relationship

The following is an example of an adjacency list + inheritance. This works as expected, but if I try to use it in another model, Mammut, as a relationship, it throws this error:
Traceback (most recent call last):
File "bin/py", line 73, in <module>
exec(compile(__file__f.read(), __file__, "exec"))
File "../adjacency_list.py", line 206, in <module>
create_entries(IntTreeNode)
File "../adjacency_list.py", line 170, in create_entries
mut.nodes.append(node)
File "/home/xxx/.buildout/eggs/SQLAlchemy-0.9.8-py3.4-linux-x86_64.egg/sqlalchemy/orm/dynamic.py", line 304, in append
attributes.instance_dict(self.instance), item, None)
File "/home/xxx/.buildout/eggs/SQLAlchemy-0.9.8-py3.4-linux-x86_64.egg/sqlalchemy/orm/dynamic.py", line 202, in append
self.fire_append_event(state, dict_, value, initiator)
File "/home/xxx/.buildout/eggs/SQLAlchemy-0.9.8-py3.4-linux-x86_64.egg/sqlalchemy/orm/dynamic.py", line 99, in fire_append_event
value = fn(state, value, initiator or self._append_token)
File "/home/xxx/.buildout/eggs/SQLAlchemy-0.9.8-py3.4-linux-x86_64.egg/sqlalchemy/orm/attributes.py", line 1164, in emit_backref_from_collection_append_event
child_impl.append(
AttributeError: '_ProxyImpl' object has no attribute 'append'
The Code:
from sqlalchemy import (Column, ForeignKey, Integer, String, create_engine,
                        Float)
from sqlalchemy.orm import (Session, relationship, backref, joinedload_all)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.collections import attribute_mapped_collection
from sqlalchemy.ext.declarative import declared_attr, AbstractConcreteBase

Base = declarative_base()

class Mammut(Base):
    __tablename__ = "mammut"
    id = Column(Integer, primary_key=True)
    nodes = relationship(
        'TreeNode',
        backref='mammut',
        lazy='dynamic',
        cascade="all, delete-orphan",
        # viewonly=True
    )

class TreeNode(AbstractConcreteBase, Base):
    id = Column(Integer, primary_key=True)
    name = Column(String(50), nullable=False)
    depth = Column(Integer, default=0)
    data_type = Column(String(50))

    @declared_attr
    def mammut_id(cls):
        return Column(Integer, ForeignKey('mammut.id'))

    @declared_attr
    def __tablename__(cls):
        return cls.__name__.lower()

    @declared_attr
    def __mapper_args__(cls):
        ret = {}
        if cls.__name__ != "TreeNode":
            ret = {'polymorphic_identity': cls.__name__,
                   'concrete': True,
                   # XXX redundant, makes only sense if we use one table
                   'polymorphic_on': cls.data_type}
        return ret

    @declared_attr
    def parent_id(cls):
        _fid = '%s.id' % cls.__name__.lower()
        return Column(Integer, ForeignKey(_fid))

    @declared_attr
    def children(cls):
        _fid = '%s.id' % cls.__name__
        return relationship(cls.__name__,
                            # cascade deletions
                            cascade="all, delete-orphan",
                            # many to one + adjacency list - remote_side
                            # is required to reference the 'remote'
                            # column in the join condition.
                            backref=backref("parent", remote_side=_fid),
                            # children will be represented as a dictionary
                            # on the "name" attribute.
                            collection_class=attribute_mapped_collection(
                                'name'),
                            )

    def get_path(self, field):
        if self.parent:
            return self.parent.get_path(field) + [getattr(self, field)]
        else:
            return [getattr(self, field)]

    @property
    def name_path(self):
        # XXX there is no way to query for it except we add a function with a
        # cte (recursive query) to our database, see [1] for it
        # https://stackoverflow.com/questions/14487386/sqlalchemy-recursive-hybrid-property-in-a-tree-node
        return '/'.join(self.get_path(field='name'))

    def __init__(self, name, value=None, parent=None):
        self.name = name
        self.parent = parent
        self.depth = 0
        self.value = value
        if self.parent:
            self.depth = self.parent.depth + 1

    def __repr__(self):
        ret = "%s(name=%r, id=%r, parent_id=%r, value=%r, depth=%r, " \
              "name_path=%s data_type=%s)" % (
                  self.__class__.__name__,
                  self.name,
                  self.id,
                  self.parent_id,
                  self.value,
                  self.depth,
                  self.name_path,
                  self.data_type
              )
        return ret

    def dump(self, _indent=0):
        return " " * _indent + repr(self) + \
            "\n" + \
            "".join([
                c.dump(_indent + 1)
                for c in self.children.values()]
            )

class IntTreeNode(TreeNode):
    value = Column(Integer)

class FloatTreeNode(TreeNode):
    value = Column(Float)
    miau = Column(String(50), default='zuff')

    def __repr__(self):
        ret = "%s(name=%r, id=%r, parent_id=%r, value=%r, depth=%r, " \
              "name_path=%s data_type=%s miau=%s)" % (
                  self.__class__.__name__,
                  self.name,
                  self.id,
                  self.parent_id,
                  self.value,
                  self.depth,
                  self.name_path,
                  self.data_type,
                  self.miau
              )
        return ret

if __name__ == '__main__':
    engine = create_engine('sqlite:///', echo=True)

    def msg(msg, *args):
        msg = msg % args
        print("\n\n\n" + "-" * len(msg.split("\n")[0]))
        print(msg)
        print("-" * len(msg.split("\n")[0]))

    msg("Creating Tree Table:")
    Base.metadata.create_all(engine)
    session = Session(engine)

    def create_entries(Cls):
        node = Cls('rootnode', value=2)
        Cls('node1', parent=node)
        Cls('node3', parent=node)
        node2 = Cls('node2')
        Cls('subnode1', parent=node2)
        node.children['node2'] = node2
        Cls('subnode2', parent=node.children['node2'])
        msg("Created new tree structure:\n%s", node.dump())
        msg("flush + commit:")
        # XXX this throws the error
        mut = Mammut()
        mut.nodes.append(node)
        session.add(mut)
        session.add(node)
        session.commit()
        msg("Tree After Save:\n %s", node.dump())
        Cls('node4', parent=node)
        Cls('subnode3', parent=node.children['node4'])
        Cls('subnode4', parent=node.children['node4'])
        Cls('subsubnode1', parent=node.children['node4'].children['subnode3'])
        # remove node1 from the parent, which will trigger a delete
        # via the delete-orphan cascade.
        del node.children['node1']
        msg("Removed node1. flush + commit:")
        session.commit()
        msg("Tree after save:\n %s", node.dump())
        msg("Emptying out the session entirely, "
            "selecting tree on root, using eager loading to join four levels deep.")
        session.expunge_all()
        node = session.query(Cls).\
            options(joinedload_all("children", "children",
                                   "children", "children")).\
            filter(Cls.name == "rootnode").\
            first()
        msg("Full Tree:\n%s", node.dump())
        # msg("Marking root node as deleted, flush + commit:")
        # session.delete(node)
        # session.commit()

    create_entries(IntTreeNode)
    create_entries(FloatTreeNode)

    nodes = session.query(TreeNode).filter(
        TreeNode.name == "rootnode").all()
    for idx, n in enumerate(nodes):
        msg("Full (%s) Tree:\n%s" % (idx, n.dump()))
Concrete inheritance can be very difficult, and AbstractConcreteBase itself has bugs in 0.9 that prevent elaborate mappings like this from working.
Using 1.0 (not released; use git master), I can get the major elements going as follows:
from sqlalchemy import Column, String, Integer, create_engine, ForeignKey, Float
from sqlalchemy.orm import Session, relationship
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.collections import attribute_mapped_collection
from sqlalchemy.ext.declarative import declared_attr, AbstractConcreteBase

Base = declarative_base()

class Mammut(Base):
    __tablename__ = "mammut"
    id = Column(Integer, primary_key=True)
    nodes = relationship(
        'TreeNode',
        lazy='dynamic',
        back_populates='mammut',
    )

class TreeNode(AbstractConcreteBase, Base):
    id = Column(Integer, primary_key=True)
    name = Column(String)

    @declared_attr
    def __tablename__(cls):
        if cls.__name__ == 'TreeNode':
            return None
        else:
            return cls.__name__.lower()

    @declared_attr
    def __mapper_args__(cls):
        return {'polymorphic_identity': cls.__name__, 'concrete': True}

    @declared_attr
    def parent_id(cls):
        return Column(Integer, ForeignKey(cls.id))

    @declared_attr
    def mammut_id(cls):
        return Column(Integer, ForeignKey('mammut.id'))

    @declared_attr
    def mammut(cls):
        return relationship("Mammut", back_populates="nodes")

    @declared_attr
    def children(cls):
        return relationship(
            cls,
            back_populates="parent",
            collection_class=attribute_mapped_collection('name'),
        )

    @declared_attr
    def parent(cls):
        return relationship(
            cls, remote_side="%s.id" % cls.__name__,
            back_populates='children')

class IntTreeNode(TreeNode):
    value = Column(Integer)

class FloatTreeNode(TreeNode):
    value = Column(Float)
    miau = Column(String(50), default='zuff')

e = create_engine("sqlite://", echo=True)
Base.metadata.create_all(e)

session = Session(e)

root = IntTreeNode(name='root')
IntTreeNode(name='n1', parent=root)
n2 = IntTreeNode(name='n2', parent=root)
IntTreeNode(name='n2n1', parent=n2)

m1 = Mammut()
m1.nodes.append(n2)
m1.nodes.append(root)

session.add(root)
session.commit()

session.close()

root = session.query(TreeNode).filter_by(name='root').one()
print root.children
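Note the key change from the question's version, as far as I can tell: instead of generating mammut on the node side via backref (which, combined with AbstractConcreteBase, leaves the subclasses with a non-mapped proxy attribute and triggers the '_ProxyImpl' object has no attribute 'append' error), mammut, parent, and children are all declared explicitly on TreeNode with declared_attr and wired together with back_populates, so every concrete subclass gets a real mapped relationship.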