diff --git a/gamemodel.py b/gamemodel.py index aefc27f..fec0ba6 100644 --- a/gamemodel.py +++ b/gamemodel.py @@ -2,9 +2,8 @@ import gameio import cv2 import numpy as np -def squared_distance(vec1, vec2): - """returns distance-squared between two x, y point tuples""" - return (vec1[0] - vec2[0])**2 + (vec1[1] - vec2[1])**2 +from utility import * +import pointcluster class GameModel: """Platform-independent representation of the game's state.""" @@ -19,19 +18,21 @@ class GameModel: ("ship_off", cv2.imread("images/game_assets/spaceship-off.png", 0)), ("ship_on", cv2.imread("images/game_assets/spaceship-on.png", 0)) ] + #self.missile = ("missile", cv2.imread("images/game_assets/missile.png", 0)) self.frame = None self.cv_template_thresh = 0.6 # reconfigurable at runtime - self.duplicate_dist_thresh = 10 + self.duplicate_dist_thresh = 36 def with_frame(fn): """Decorator to process screenshot to cv2 format once upon first requirement, then reuse.""" def inner(self, *args, **kwargs): if self.frame is None: - print("Fetching frame.") + #print("Fetching frame.") sshot = self.gameio.fetch_sshot() open_cv_image = np.array(sshot) # Convert RGB to BGR self.frame = open_cv_image[:, :, ::-1].copy() + self.color_frame = np.copy(self.frame) self.frame = cv2.cvtColor(self.frame, cv2.COLOR_BGR2GRAY) return fn(self, *args, **kwargs) return inner @@ -52,32 +53,80 @@ class GameModel: return asteroid_rects @with_frame - def display_results(self, results): + def display_results(self, rects = [], pointsets = [], circles = []): """Draws results on the current frame for test purposes.""" - displayable = np.copy(self.frame) - for pt, wh, label in results: - cv2.rectangle(displayable, pt, wh, 255, 1) + displayable = np.copy(self.color_frame) + for pt, wh, label in rects: + color = { "big": (255, 0, 0), + "normal": (0, 255, 0), + "small": (0, 0, 255), + "missile": (128, 0, 0), + "ship_on": (0, 0, 128), + "ship_off": (0, 64, 128)}[label] + cv2.rectangle(displayable, pt, wh, color, 1) cv2.putText(displayable, label, pt, cv2.FONT_HERSHEY_PLAIN, - 1.0, 255) + 1.0, color) + for ps in pointsets: + color = (0, 255, 255) + cv2.polylines(displayable, np.int32([ps]), True, color) + + for center, radius, label in circles: + color = (255, 255, 0) + cv2.circle(displayable, np.int32(center), int(radius), color, 1) + cv2.putText(displayable, label, np.int32(center), + cv2.FONT_HERSHEY_PLAIN, + 1.0, color) + cv2.imshow("Results", displayable) cv2.waitKey(0) @with_frame - def find_ships(self): + def frame_sift(self): sift = cv2.SIFT_create() - frame_kp, frame_desc = sift.detectAndCompute(self.frame, None) - kp_desc = [] # list of (keypoints, descriptions) for all ship sprites - for label, s in self.ships: - kp_desc.append((label, sift.detectAndCompute(s, None))) - bf = cv2.BFMatcher(cv2.NORM_L1, crossCheck=True) - matchsets = [] - for label, kpdesc in kp_desc: - _, desc = kpdesc - matchsets.append((label, bf.match(frame_desc, desc))) - return { "matchsets": matchsets, - "kp_desc": kp_desc - } + kp_desc = {} # dict of (keypoints, descriptions) for all ship sprites + kp_desc["frame"] = sift.detectAndCompute(self.frame, None) + frame_kp, frame_desc = kp_desc["frame"] +## for label, s in self.ships: +## kp_desc[label] = sift.detectAndCompute(s, None) +## bf = cv2.BFMatcher(cv2.NORM_L1, crossCheck=True) +## matchsets = {} +## for label in kp_desc: +## _, desc = kp_desc[label] +## matchsets[label] = bf.match(frame_desc, desc) +## #return { "matchsets": matchsets, +## # "kp_desc": kp_desc +## # } + ship_rsq = rect_radius_squared(*self.ships[0][1].shape) + #print(f"max radius^2: {ship_rsq}") + clusters = pointcluster.cluster_set([k.pt for k in frame_kp], sqrt(ship_rsq)) + + return clusters + + @with_frame + def find_ships(self): + ship_rects = [] + for label, a in self.ships: + h, w = a.shape + res = cv2.matchTemplate(self.frame, a, cv2.TM_CCOEFF_NORMED) + loc = np.where( res >= self.cv_template_thresh) + for pt in zip(*loc[::-1]): + if not ship_rects or squared_distance(ship_rects[-1][0], pt) > self.duplicate_dist_thresh: + ship_rects.append((pt, (pt[0] + w, pt[1] + h), label)) + return ship_rects + +## @with_frame +## def find_missiles(self): +## """This technique does not work for the 9x9 pixel missile image.""" +## missile_rects = [] +## label, img = self.missile +## h, w = img.shape +## res = cv2.matchTemplate(self.frame, img, cv2.TM_CCOEFF_NORMED) +## loc = np.where( res >= self.cv_template_thresh) +## for pt in zip(*loc[::-1]): +## if not missile_rects or squared_distance(missile_rects[-1][0], pt) > self.duplicate_dist_thresh: +## missile_rects.append((pt, (pt[0] + w, pt[1] + h), label)) +## return missile_rects if __name__ == '__main__': import platform @@ -99,8 +148,14 @@ if __name__ == '__main__': io.loc = pyscreeze.Box(0, 25, 800, 599) #input("Press to detect asteroids on screen.") - results = gm.find_asteroids() - print(f"Found {len(results)} asteroids") - for a in results: - print(a[0]) # position tuple - gm.display_results(results) + a_results = gm.find_asteroids() + print(f"Found {len(a_results)} asteroids") + #for a in a_results: + # print(a[0]) # position tuple + #gm.display_results(results) + s_results = gm.frame_sift() + ship_results = gm.find_ships() + polygons = [c.points for c in s_results] + #circles = [(c.center, c.max_distance, f"cluster_{i}") for i, c in enumerate(s_results)] + r_circles = [(c.center, sqrt(rect_radius_squared(*gm.ships[0][1].shape)), f"cluster_{i}") for i, c in enumerate(s_results)] + gm.display_results(rects=a_results+ship_results, pointsets=polygons, circles=r_circles) diff --git a/pointcluster.py b/pointcluster.py new file mode 100644 index 0000000..e89719d --- /dev/null +++ b/pointcluster.py @@ -0,0 +1,58 @@ +from utility import * + +class PointCluster: + def __init__(self): + self.points = [] + self.center = (0, 0) + self.max_distance = None + + def update(self): + if len(self.points) == 0: return + self.center = (sum([p[0] for p in self.points]) / len(self.points), + sum([p[1] for p in self.points]) / len(self.points)) + self.max_distance = sqrt(max( + [squared_distance(self.center, p) for p in self.points])) + + def add(self, pt): + self.points.append(pt) + self.update() + + def pop(self): + p = self.points.pop(-1) + self.update() + return p + + def __repr__(self): + c = f"({self.center[0]:.1f},{self.center[1]:.1f})" + return f"" + +def cluster_set(points, maxradius): + """returns a list of PointCluster objects. Points are fit within circles of maxradius""" + clusters = [] + for pt in points: + if len(clusters) == 0: + #print("first cluster") + clusters.append(PointCluster()) + clusters[-1].add(pt) + continue + # add point to its nearest cluster + scored_clusters = [(c, squared_distance(pt, c.center)) for c in clusters] + scored_clusters.sort(key=lambda i: i[1]) + winner = scored_clusters[0][0] + winner.add(pt) + + # if maxradius constraint was violated, pop the newest point & add new cluster + if winner.max_distance > maxradius: + #print(f"{winner.max_distance} > {maxradius}; new cluster") + winner.pop() + clusters.append(PointCluster()) + clusters[-1].add(pt) + + # refine step - accept centers as fixed, put points in closest center + new_clusters = {c.center: PointCluster() for c in clusters} + closest = lambda pt: sorted(new_clusters.keys(), key= lambda i: squared_distance(pt, i))[0] + for point in points: + new_clusters[closest(point)].add(point) + #print(clusters) + #print(new_clusters.values()) + return new_clusters.values() diff --git a/utility.py b/utility.py new file mode 100644 index 0000000..599c38e --- /dev/null +++ b/utility.py @@ -0,0 +1,9 @@ +from math import sqrt + +def squared_distance(vec1, vec2): + """returns distance-squared between two x, y point tuples""" + return (vec1[0] - vec2[0])**2 + (vec1[1] - vec2[1])**2 + +def rect_radius_squared(w, h): + """Returns the radius^2 of the circle inscribed in a rectangle of w * h""" + return (w/2)**2 + (h/2)**2