Plotly interactive Multiplot is not working in Google Colab - deep-learning

Hello I'm using Google colab to display multiplot that's why I tried to create one scatterplot with 3 scatter in it and under that, there is a parallel categories plot that should be connected to scatter but when I try to select a group of points in scatter, I cannot change color in parcat it's always grey, I think those functions such as on_click and on_selection does not work but I couldn't replace them with another code can someone help me?
import plotly.graph_objects as go
from ipywidgets import widgets
import pandas as pd
import numpy as np
# Build parcats dimensions
#here dfTot is my big class include every layer(input,1,2,3,output)
#Check image under the Code
dfTot = pd.concat([df1,df2,df3,df4], axis=1,join='inner')
categorical_dimensions = ['Layer Input', 'Layer 2', 'Layer 3','Layer Output'];
dimensions = [dict(values=dfTot[label], label=label) for label in categorical_dimensions]
# Build colorscale
color = np.zeros(len(dfTot), dtype='uint8')
colorscale = [[0, 'gray'], [1, 'firebrick']]
# Build figure as FigureWidget
fig = go.FigureWidget(
data=[
go.Scatter(x=layer_activation1[:,0], y=layer_activation1[:,1],showlegend = False,hovertemplate =y+"<br>"+"Cluster Groupe : "+df2['Layer 2']+"<br>"+"Layer 2"'<extra></extra>',
marker={'color': 'gray'}, mode='markers', selected={'marker': {'color': 'firebrick'}},
unselected={'marker': {'opacity': 0.3}}),
go.Scatter(x=layer_activation2[:,0], y=layer_activation2[:,1],showlegend = False,hovertemplate =y+"<br>"+"Cluster Groupe : "+df3['Layer 3']+"<br>"+"Layer 3"'<extra></extra>',
marker={'color': 'gray'}, mode='markers', selected={'marker': {'color': 'firebrick'}},
unselected={'marker': {'opacity': 0.3}}),
go.Scatter(x=layer_activation3[:,0], y=layer_activation3[:,1],showlegend = False,hovertemplate =y+"</br>"+"Cluster Groupe : "+df4['Layer Output']+"<br>"+"Layer Output"'<extra></extra>',
marker={'color': 'gray'}, mode='markers', selected={'marker': {'color': 'firebrick'}},
unselected={'marker': {'opacity': 0.3}}),
go.Parcats(
domain={'y': [0, 0.4]}, dimensions=dimensions,
line={'colorscale': colorscale, 'cmin': 0,
'cmax': 1, 'color': color, 'shape': 'hspline'})
])
fig.update_layout(
height=800, xaxis={'title': 'Axis x'},
yaxis={'title': 'Axis y', 'domain': [0.6, 1]},
dragmode='lasso', hovermode='closest')
# Update color callback
def update_color(trace, points, state):
new_color = np.zeros(len(dfTot), dtype='uint8')
# Update scatter selection
fig.data[0].selectedpoints = points.point_inds
new_color[points.point_inds] = 1
fig.data[3].line.color = new_color
fig.data[1].selectedpoints = points.point_inds
new_color[points.point_inds] = 1
fig.data[3].line.color = new_color
fig.data[2].selectedpoints = points.point_inds
# Update parcats colors
new_color[points.point_inds] = 1
fig.data[3].line.color = new_color
# Register callback on scatter selection...
fig.data[0].on_selection(update_color)
fig.data[1].on_selection(update_color)
fig.data[2].on_selection(update_color)
# and parcats click
fig.data[3].on_click(update_color)
Here is the screenshot of plot

Related

Clearing a matlibplot from a Tkinter GUI

The image shows the GUI:
I am making a GUI with Tkinter which can show graphs which have been plotted with matplotlib.
The graphs get their x and y values from a JSON file.
The data from the JSON file gets collected by some code which I wrote (this code is included in the code below)
What works: The collecting of data from the JSON-file, plotting the data in a graph via matplotlib, and showing this graph on a tkinter canvas all works fine.
Problem: I cannot clear the canvas, such that I can display another graph based on data from another JSON file. I have to close the program, start it again, and select a different JSON file, if I want to see another graph.
Here is the code:
################################### GUI program for graphs #######################################
##### Scrollbar for the listbox #####
Myframe = Frame(inspect_data)
Myframe.pack(side=tk.LEFT, fill=Y)
my_scrollbar = Scrollbar(Myframe, orient=VERTICAL)
#####Listbox#####
vores_listebox = Listbox(Myframe, width=22, height=40,
yscrollcommand=my_scrollbar.set)
vores_listebox.pack(side=LEFT, expand=True)
my_scrollbar.config(command=vores_listebox.yview)
my_scrollbar.pack(side=LEFT, fill=Y)
Myframe.pack()
###### Button functions ######
path_us = ""
def vis_NS():
vores_listebox.delete(0, END)
# Husk at ændre mapper##
global path_us
path_us = "C://Users//canal//OneDrive//Dokumenter//AAU//3. semester//1//"
for x in os.listdir(path_us):
if x.endswith(".json"):
vores_listebox.insert(END, x)
def vis_OS():
global path_us
path_us = "C://Users//canal//OneDrive//Dokumenter//AAU//3. semester//2//"
vores_listebox.delete(0, END)
# De NS;OS,US,MS skal ligge i hver sin mappe#
for x in os.listdir(path_us):
if x.endswith(".json"):
vores_listebox.insert(END, x)
def vis_US():
vores_listebox.delete(0, END)
for x in os.listdir("C://Users//canal//OneDrive//Dokumenter//AAU//3. semester//2"):
if x.endswith(".json"):
vores_listebox.insert(END, x)
def vis_MS():
vores_listebox.delete(0, END)
for x in os.listdir("C://Users//canal//OneDrive//Dokumenter//AAU//3. semester//2"):
if x.endswith(".json"):
vores_listebox.insert(END, x)
#### Buttons #####
button1 = Button(inspect_data,
text="Normal screwing",
width=17,
height=2,
bg="white",
fg="black",
command=vis_NS
)
button1.place(x=5, y=10)
button2 = Button(inspect_data,
text="Over screwing",
width=17,
height=2,
bg="white",
fg="black",
command=vis_OS
)
button2.place(x=5, y=50)
button3 = Button(inspect_data,
text="Under screwing",
width=17,
height=2,
bg="white",
fg="black",
command=vis_US
)
button3.place(x=5, y=90)
button4 = Button(inspect_data,
text="Missing screw",
width=17,
height=2,
bg="white",
fg="black",
command=vis_MS
)
button4.place(x=5, y=130)
######en fil bliver trykket#####
# placing the canvas on the Tkinter window
def items_selected(event):
selected_indices = vores_listebox.curselection()
selected_json = ",".join([vores_listebox.get(i) for i in selected_indices])
full_file_path = path_us + selected_json
open_json = js.load(open(full_file_path, "r"))
time = [open_json['XML_Data']
['Wsk3Vectors']['X_Axis']['Values']['float']]
rpm = [open_json['XML_Data']['Wsk3Vectors']
['Y_AxesList']['AxisData'][0]['Values']['float']]
torque = [open_json['XML_Data']['Wsk3Vectors']
['Y_AxesList']['AxisData'][1]['Values']['float']]
current = [open_json['XML_Data']['Wsk3Vectors']
['Y_AxesList']['AxisData'][2]['Values']['float']]
angle = [open_json['XML_Data']['Wsk3Vectors']
['Y_AxesList']['AxisData'][3]['Values']['float']]
depth = [open_json['XML_Data']['Wsk3Vectors']
['Y_AxesList']['AxisData'][4]['Values']['float']]
####### Using Matlib.pyplot to plot 5 graphs #######
plt.rcParams["figure.figsize"] = (7, 10)
plt.subplot(5, 1, 1)
plt.scatter(time, rpm, c="b", linewidths=2,
marker=",", edgecolor="b", s=1, alpha=0.5)
plt.title(selected_json)
plt.gca().axes.xaxis.set_ticklabels([])
plt.ylabel("RPM")
plt.grid()
plt.subplot(5, 1, 2)
plt.scatter(time, torque, c="g", linewidths=1,
marker=",", edgecolor="g", s=1, alpha=0.3)
plt.gca().axes.xaxis.set_ticklabels([])
plt.ylabel("Torque [Nm]")
plt.grid()
plt.subplot(5, 1, 3)
plt.scatter(time, current, c="r", linewidths=2,
marker=",", edgecolor="r", s=1, alpha=0.5)
plt.gca().axes.xaxis.set_ticklabels([])
plt.ylabel("Current [Amps]")
plt.grid()
plt.subplot(5, 1, 4)
plt.scatter(time, angle, c="m", linewidths=2,
marker=",", edgecolor="m", s=1, alpha=0.5)
plt.gca().axes.xaxis.set_ticklabels([])
plt.ylabel("Angle [RAD]")
plt.grid()
plt.subplot(5, 1, 5)
plt.scatter(time, depth, c="c", linewidths=2,
marker=",", edgecolor="c", s=1, alpha=0.5)
plt.xlabel("Time [ms]")
plt.ylabel("Depth [mm]")
plt.grid()
#### Sowing all the subplots in a tkinter canvas ########
fig = plt.figure()
canvas = FigureCanvasTkAgg(fig, master=inspect_data)
canvas.get_tk_widget().pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
# toolbar = matplotlib.NavigationToolbar2TkAgg(
# canvas, self)
# toolbar.update()
canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
vores_listebox.bind('<<ListboxSelect>>', items_selected)
root.after(1000, converter)
root.mainloop()
I have tried canvas.delete(all), plt.clf() among other things. One solution I think could work is:
If I start the function def items_selected to clear the figure inside the canvas by using: plt.clf(fig), however "fig" is not defined at this point, so python wouldn't know what this means.
it seems to me that the problem originates from plotting to plt, and not directly onto some axis from a figure linked to tkinter. Try to bind your subplot figure directly to your canvas outside the items_selected() function, and then redraw the canvas whenever you change the underlying matplotlib figure:
fig, axes = plt.subplots(5, 1, figsize=(7, 10))
canvas = FigureCanvasTkAgg(fig, inspect_data)
canvas.get_tk_widget().pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
def items_selected(event):
.....
axes[0].scatter(time, rpm, c="b", linewidths=2,
marker=",", edgecolor="b", s=1, alpha=0.5)
axes[0].set_title(selected_json)
axes[0].set_xticklabels([])
axes[0].set_ylabel("RPM")
axes[0].grid()
....
canvas.draw()

Issues in creating pysimplegui or Tkinter graph GUI by reading csv file , cleaning it and plotting graph (histogram+PDF)

I want to create GUI which should automatically clean data in csv file once selected and plot superimposed PDF & histogram graph. I have uploaded basic python program which generates the required graph but I am unbale to convert it into interface. I guess, only "open file" & "plot" buttons would suffice the requirement. image- want to retrieve data from 'N'th column (13) only with skipping top 4 rows
I am basically from metallurgy background and trying my hands in this field.
Any help would be greatly appreciated
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
raw_data = pd.read_csv("D:/Project/Python/NDC/Outlier_ND/800016_DAT.csv",skiprows=4,header=None)
clean = pd.DataFrame(raw_data)
data1 = clean.iloc[:, [13]]
Q1 = data1.quantile(0.25)
Q3 = data1.quantile(0.75)
IQR = Q3 - Q1
data_IQR = data1[~((data1 < (Q1 - 1.5 * IQR)) |(data1 > (Q3 + 1.5 * IQR))).any(axis=1)]
data_IQR.shape
print(data1.shape)
print(data_IQR.shape)
headerList = ['Actual_MR']
data_IQR.to_csv(r'D:\Project\Python\NDC\Outlier_ND\800016_DAT_IQR.csv', header=headerList, index=False)
data = pd.read_csv("D:/Project/Python/NDC/Outlier_ND/800016_DAT_IQR.csv")
mean, sd = norm.fit(data)
plt.hist(data, bins=25, density=True, alpha=0.6, facecolor = '#2ab0ff', edgecolor='#169acf', linewidth=0.5)
xmin, xmax = plt.xlim()
x = np.linspace(xmin, xmax, 100)
p = norm.pdf(x, mean, sd)
plt.plot(x, p, 'red', linewidth=2)
title = " Graph \n mean: {:.2f} and SD: {:.2f}".format(mean, sd)
plt.title(title)
plt.xlabel('MR')
plt.ylabel('Pr')
plt.show()
Following code demo how PySimpleGUI to work with matplotlib, detail please find all remark in script.
import math, random
from pathlib import Path
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
import PySimpleGUI as sg
# 1. Define the class as the interface between matplotlib and PySimpleGUI
class Canvas(FigureCanvasTkAgg):
"""
Create a canvas for matplotlib pyplot under tkinter/PySimpleGUI canvas
"""
def __init__(self, figure=None, master=None):
super().__init__(figure=figure, master=master)
self.canvas = self.get_tk_widget()
self.canvas.pack(side='top', fill='both', expand=1)
# 2. create PySimpleGUI window, a fixed-size Frame with Canvas which expand in both x and y.
font = ("Courier New", 11)
sg.theme("DarkBlue3")
sg.set_options(font=font)
layout = [
[sg.Input(expand_x=True, key='Path'),
sg.FileBrowse(file_types=(("ALL CSV Files", "*.csv"), ("ALL Files", "*.*"))),
sg.Button('Plot')],
[sg.Frame("", [[sg.Canvas(background_color='green', expand_x=True, expand_y=True, key='Canvas')]], size=(640, 480))],
[sg.Push(), sg.Button('Exit')]
]
window = sg.Window('Matplotlib', layout, finalize=True)
# 3. Create a matplotlib canvas under sg.Canvas or sg.Graph
fig = Figure(figsize=(5, 4), dpi=100)
ax = fig.add_subplot()
canvas = Canvas(fig, window['Canvas'].Widget)
# 4. initial for figure
ax.set_title(f"Sensor Data")
ax.set_xlabel("X axis")
ax.set_ylabel("Y axis")
ax.set_xlim(0, 1079)
ax.set_ylim(-1.1, 1.1)
ax.grid()
canvas.draw() # do Update to GUI canvas
# 5. PySimpleGUI event loop
while True:
event, values = window.read()
if event in (sg.WINDOW_CLOSED, 'Exit'):
break
elif event == 'Plot':
"""
path = values['Path']
if not Path(path).is_file():
continue
"""
# 6. Get data from path and plot from here
ax.cla() # Clear axes first if required
ax.set_title(f"Sensor Data")
ax.set_xlabel("X axis")
ax.set_ylabel("Y axis")
ax.grid()
theta = random.randint(0, 359)
x = [degree for degree in range(1080)]
y = [math.sin((degree+theta)/180*math.pi) for degree in range(1080)]
ax.plot(x, y)
canvas.draw() # do Update to GUI canvas
# 7. Close window to exit
window.close()

How to extend ‘text’ field in Graph.ExtendData with Plotly-Dash?

I have an animated graph that I update in a clientside callback. However, I want to update the text as well as the x and y values of the traces in Graph.extendData(), but it seems that that doesn't work. Is there something I'm missing? Alternatively, is there a different method I should be using instead?
Adopting the code from this post (Plotly/Dash display real time data in smooth animation), I'd like something like this, but where updating the text with extendData actually worked:
import dash
import dash_html_components as html
import dash_core_components as dcc
import numpy as np
from dash.dependencies import Input, Output, State
# Example data (a circle).
resolution = 1000
t = np.linspace(0, np.pi * 2, resolution)
x, y = np.cos(t), np.sin(t)
text = str(t)
# Example app.
figure = dict(data=[{'x': [], 'y': []}], text = [], layout=dict(xaxis=dict(range=[-1, 1]), yaxis=dict(range=[-1, 1])))
app = dash.Dash(__name__, update_title=None) # remove "Updating..." from title
app.layout = html.Div([
dcc.Graph(id='graph', figure=dict(figure)), dcc.Interval(id="interval", interval=25),
dcc.Store(id='offset', data=0), dcc.Store(id='store', data=dict(x=x, y=y, text=text, resolution=resolution)),
])
# This makes the graph fail to draw instead of just extending the text as wel!
app.clientside_callback(
"""
function (n_intervals, data, offset) {
offset = offset % data.x.length;
const end = Math.min((offset + 10), data.x.length);
return [[{x: [data.x.slice(offset, end)], y: [data.y.slice(offset, end)], text: [data.text.slice(offset, end)]}, [0], 500], end]
}
""",
[Output('graph', 'extendData'), Output('offset', 'data')],
[Input('interval', 'n_intervals')], [State('store', 'data'), State('offset', 'data')]
)
if __name__ == '__main__':
app.run_server()
Alternatively, is there a different method I should be using instead?

unknown resampling filter error when trying to create my own dataset with pytorch

I am trying to create a CNN implemented with data augmentation in pytorch to classify dogs and cats. The issue that I am having is that when I try to input my dataset and enumerate through it I keep getting this error:
Traceback (most recent call last):
File "<ipython-input-55-6337e0536bae>", line 75, in <module>
for i, (inputs, labels) in enumerate(trainloader):
File "/usr/local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 188, in __next__
batch = self.collate_fn([self.dataset[i] for i in indices])
File "/usr/local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 188, in <listcomp>
batch = self.collate_fn([self.dataset[i] for i in indices])
File "/usr/local/lib/python3.6/site-packages/torchvision/datasets/folder.py", line 124, in __getitem__
img = self.transform(img)
File "/usr/local/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 42, in __call__
img = t(img)
File "/usr/local/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 147, in __call__
return F.resize(img, self.size, self.interpolation)
File "/usr/local/lib/python3.6/site-packages/torchvision/transforms/functional.py", line 197, in resize
return img.resize((ow, oh), interpolation)
File "/usr/local/lib/python3.6/site-packages/PIL/Image.py", line 1724, in resize
raise ValueError("unknown resampling filter")
ValueError: unknown resampling filter
and I really dont know whats wrong with my code. I have provided the code below:
# Creating the CNN
# Importing the libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import transforms
#Creating the CNN Model
class CNN(nn.Module):
def __init__(self, nb_outputs):
super(CNN, self).__init__() #activates the inheritance and allows the use of all the tools in the nn.Module
#making the 3 convolutional layers that will be used in the convolutional neural network
self.convolution1 = nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 5) #kernal_size -> the deminson of the feature detector e.g kernel_size = 5 => feature detector of size 5x5
self.convolution2 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 2)
#making 2 full connections one to connect the inputs of the ANN to the hidden layer and another to connect the hidden layer to the outputs of the ANN
self.fc1 = nn.Linear(in_features = self.count_neurons((1, 64,64)), out_features = 40)
self.fc2 = nn.Linear(in_features = 40, out_features = nb_outputs)
def count_neurons(self, image_dim):
x = Variable(torch.rand(1, *image_dim)) #this variable repersents a fake image to allow us to compute the number of neruons
#in order to pass the elements of the tuple image_dim into our function as a list of arguments we need to add a * before image_dim
#since x will be going into our neural network we need to convert it into a torch variable using the Variable() function
x = F.relu(F.max_pool2d(self.convolution1(x), 3, 2)) #first we apply the convolution to x then apply max_pooling to the convolutional fake images and then activate all the neurons in the pooling layer
x = F.relu(F.max_pool2d(self.convolution2(x), 3, 2)) #the signals are now propragated up to the thrid convoulational layer
#Now to flatten x to obtain the number of neurons in the flattening layer
return x.data.view(1, -1).size(1) #this will flatten x into a huge vector and returns the size of the vector, that size repersents the number of neurons that will be inputted into the ANN
#even though x is not a real image from the game since the size of the flattened vector only depends on the dimention of the inputted image we can just set x to have the same dimentions as the image
def forward(self, x):
x = F.relu(F.max_pool2d(self.convolution1(x), 3, 2)) #first we apply the convolution to x then apply max_pooling to the convolutional fake images and then activate all the neurons in the pooling layer
x = F.relu(F.max_pool2d(self.convolution2(x), 3, 2))
#flattening layer of the CNN
x = x.view(x.size(0), -1)
#x is now the inputs to the ANN
x = F.relu(self.fc1(x)) #we propagte the signals from the flatten layer to the full connected layer and activate the neruons by breaking the linearilty with the relu function
x = F.sigmoid(self.fc2(x))
#x is now the output neurons of the ANN
return x
train_tf = transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.Resize(64,64),
transforms.RandomRotation(20),
transforms.RandomGrayscale(.2),
transforms.ToTensor()])
test_tf = transforms.Compose([transforms.Resize(64,64),
transforms.ToTensor()])
training_set = torchvision.datasets.ImageFolder(root = './dataset/training_set',
transform = train_tf)
test_set = torchvision.datasets.ImageFolder(root = './dataset/test_set',
transform = transforms.Compose([transforms.Resize(64,64),
transforms.ToTensor()]) )
trainloader = torch.utils.data.DataLoader(training_set, batch_size=32,
shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(test_set, batch_size= 32,
shuffle=False, num_workers=0)
#training the model
cnn = CNN(1)
cnn.train()
loss = nn.BCELoss()
optimizer = optim.Adam(cnn.parameters(), lr = 0.001) #the optimizer => Adam optimizer
nb_epochs = 25
for epoch in range(nb_epochs):
train_loss = 0.0
train_acc = 0.0
total = 0.0
for i, (inputs, labels) in enumerate(trainloader):
inputs, labels = Variable(inputs), Variable(labels)
cnn.zero_grad()
outputs = cnn(inputs)
loss_error = loss(outputs, labels)
optimizer.step()
_, pred = torch.max(outputs.data, 1)
total += labels.size(0)
train_loss += loss_error.data[0]
train_acc += (pred == labels).sum()
train_loss = train_loss/len(training_loader)
train_acc = train_acc/total
print('Epoch: %d, loss: %.4f, accuracy: %.4f' %(epoch+1, train_loss, train_acc))
The folder arrangement for the code is /dataset/training_set and inside the training_set folder are two more folders one for all the cat images and the other for all the dog images. Each image is name either dog.xxxx.jpg or cat.xxxx.jpg, where the xxxx represents the number so for the first cat image it would be cat.1.jpg up to cat.4000.jpg. This is the same format for the test_set folder. The number of training images is 8000 and the number of test images is 2000. If anyone can point out my error I would greatly appreciate it.
Thank you
Try to set the desired size in transforms.Resize as a tuple:
transforms.Resize((64, 64))
PIL is using the second argument (in your case 64) as the interpolation method.
in torchvision.transforms.Compose([put every transform in these brackets]),
This, will not give the error.

The Tensorflow Object_detection API 's visualize don't work

when I am using the API of Object_detection,I followed the instruction ,everything is fine .However ,when I begin to test my picture , I met a problem , it seems that the function named
" visualize_boxes_and_labels_on_image_array " ( in the 57 line ) didn't work . Here is my source codes
import cv2
import numpy as np
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
class TOD(object):
def __init__(self):
self.PATH_TO_CKPT = '/home/xiyou/Desktop/ssd_training/result/frozen_inference_graph.pb'
self.PATH_TO_LABELS = '/home/xiyou/Desktop/ssd_training/detection_for_smoke.pbtxt'
self.NUM_CLASSES = 1
self.detection_graph = self._load_model()
self.category_index = self._load_label_map()
def _load_model(self):
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
return detection_graph
def _load_label_map(self):
label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map,
max_num_classes=self.NUM_CLASSES,
use_display_name=True)
category_index = label_map_util.create_category_index(categories)
return category_index
def detect(self, image):
with self.detection_graph.as_default():
with tf.Session(graph=self.detection_graph) as sess:
# Expand dimensions since the model expects images to have shape: [1, None, None, 3]
image_np_expanded = np.expand_dims(image, axis=0)
image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
# Actual detection.
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
print(boxes, scores, classes, num_detections)
#print(np.squeeze(boxes))
# Visualization of the results of a detection.
#######Here is the problem
# image1 = vis_util.visualize_boxes_and_labels_on_image_array(
image, #######Here is the problem
np.squeeze(boxes),
np.squeeze(classes).astype(np.int32),
np.squeeze(scores),
self.category_index,
use_normalized_coordinates=True,
line_thickness=50,
)
#print(np.squeeze(boxes),np.squeeze(classes))
cv2.namedWindow("detection")
cv2.imshow("detection", image1)
cv2.waitKey(0)
if __name__ == '__main__':
image = cv2.imread('/home/xiyou/Pictures/timg1.jpg')
detecotr = TOD()
detecotr.detect(image)
when I run this code , the image did show ,but nothing changed , no detected area in the pic and no an other informations . the input pic is the same as the out image . But when I was Debug , I found the Varibles such as soucres , classes , boxes do have values.
Is anyone can help me ? Thanks!!!
And my Tensorflow version is 1.4.0 , CUDA 8.0 in Ubuntu 16.04