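"""Takes data in COCO format (annotation JSON + image folder), trains a Detectron2 Mask R-CNN on it, and visualizes predictions."""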
import detectron2 as detectron2
from detectron2.engine.hooks import HookBase
from detectron2.evaluation import inference_context
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import DatasetMapper, build_detection_test_loader
from detectron2.data.datasets import register_coco_instances
import detectron2
from detectron2.data.datasets import register_coco_instances
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, VisImage
from detectron2.data import MetadataCatalog
from detectron2.data.catalog import DatasetCatalog
import numpy as np
import json
# COCO annotation file and image folder of the simulated dataset.
# json_file = '../Data/vineyard/labels_coco_vineyard.json'
json_file = '../Data/sim/sim.json'
image_root = '../Data/sim/images/'

# Load the raw COCO annotation dict so it can be inspected or edited before use,
# e.g. to remap a category id and save a modified copy:
#   data['annotations'][0]['category_id'] = 1
#   with open("sample.json", "w") as outfile:
#       json.dump(data, outfile)
# The file is opened here (previously `f` was never defined, so json.load raised
# NameError) and closed automatically by the `with` block. JSON is UTF-8 by spec,
# so decode it explicitly rather than with the platform's default encoding.
with open(json_file, encoding="utf-8") as f:
    data = json.load(f)
# NOTE(review): `data` and `image_root` are not used anywhere else in the visible code.
# Register the vineyard training split (COCO JSON + image folder) under the name
# "my_dataset_train", then look up its metadata object and its per-image records.
train_json = '../Data/vineyard/labels_coco_vineyard.json'
train_image = '../Data/vineyard/images'
register_coco_instances(
    name="my_dataset_train",
    metadata={},
    json_file=train_json,
    image_root=train_image,
)
my_dataset_train_metadata = MetadataCatalog.get("my_dataset_train")
# Each record is a dict; the code below reads its "file_name" key.
dataset_dicts_train = DatasetCatalog.get("my_dataset_train")
import random
import cv2
from detectron2.utils.visualizer import Visualizer
# For visualisation: draw the ground-truth annotations of a few random training
# images, one window at a time (press any key to advance).
# random.sample raises ValueError if asked for more items than exist, so cap it.
num_preview = min(3, len(dataset_dicts_train))
for d in random.sample(dataset_dicts_train, num_preview):
    img = cv2.imread(d["file_name"])
    if img is None:
        # cv2.imread returns None (it does not raise) for missing/unreadable files.
        print("Could not read image:", d["file_name"])
        continue
    # cv2 loads channels as BGR; reverse the channel axis before handing it to Visualizer.
    visualizer = Visualizer(img[:, :, ::-1], metadata=my_dataset_train_metadata, scale=0.5)
    vis = visualizer.draw_dataset_dict(d)
    cv2.namedWindow("asd", cv2.WINDOW_NORMAL)
    # Reverse the channels back for cv2.imshow.
    cv2.imshow('asd', vis.get_image()[:, :, ::-1])
    cv2.waitKey(0)
cv2.destroyAllWindows()
#%%
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg
import os
# Training config: Mask R-CNN, ResNet-50 C4 backbone (3x schedule), fine-tuned from
# COCO-pretrained weights on the dataset registered as "my_dataset_train".
# The config file and the pretrained checkpoint are both resolved from the SAME
# model-zoo name so they always describe the same architecture. Previously the C4
# config (read from ../../detectron2/detectron2/configs/...) was paired with the
# R_50_FPN_3x checkpoint. FPN and C4 models have different parameter names and
# shapes, so most pretrained layers could not be loaded.
model_config_name = "COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x.yaml"
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(model_config_name))
cfg.OUTPUT_DIR = './out_mask_roza_apple/'
# cfg.INPUT.MASK_FORMAT = "bitmask"
cfg.DATASETS.TRAIN = ("my_dataset_train",)
cfg.DATASETS.TEST = ()  # no metrics implemented for this dataset
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_config_name)
# To start from a previously trained model instead:
# cfg.MODEL.WEIGHTS = 'out_mask_roza_apple/model_final.pth'
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.001
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # faster, and good enough for this toy dataset
# NOTE(review): must equal the number of categories in train_json. The previous
# comment ("3 classes (data, fig, hazelnut)") disagreed with this value. Confirm 2.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
cfg.SOLVER.WARMUP_ITERS = 100
cfg.SOLVER.MAX_ITER = 15000  # adjust up if val mAP is still rising, adjust down if overfit
# Iterations at which the LR is multiplied by GAMMA. The old (10, 15) decayed the LR
# twice inside the 100-iteration warmup, leaving 0.001 * 0.05**2 = 2.5e-6 for the
# remaining ~15000 iterations. Place the decays at the same fractions of the run as
# detectron2's 3x schedule (210k and 250k of 270k iterations), tied to MAX_ITER.
cfg.SOLVER.STEPS = (int(0.78 * cfg.SOLVER.MAX_ITER), int(0.93 * cfg.SOLVER.MAX_ITER))
cfg.SOLVER.GAMMA = 0.05
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
# resume=True: continue from the last checkpoint in OUTPUT_DIR if one exists,
# otherwise initialise from cfg.MODEL.WEIGHTS.
trainer.resume_or_load(resume=True)
trainer.train()
#%% For inference
import os
import random
import cv2
# Inference reuses the training `cfg` built above; only the weights change.
# Point them at "model_final.pth" in OUTPUT_DIR (presumably the checkpoint written
# at the end of trainer.train(); confirm it exists before running this cell).
# Optional: set cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST (e.g. 0.5) to filter low-score detections.
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
# Wrap the model in a single-image predictor: predictor(image) -> outputs dict.
predictor = DefaultPredictor(cfg)
from detectron2.utils.visualizer import ColorMode
import matplotlib.pyplot as plt
# Run the trained predictor on a random subset of images from test_path and show
# the first predicted instance drawn on each (press any key to advance).
test_path = '/home/achyut/Achyut/Projects/ResearchWorks/trunk_detection//train/'
tst = os.listdir(test_path)
# random.sample raises ValueError if asked for more items than exist, so cap it.
num_show = min(10, len(tst))
for d in random.sample(tst, num_show):
    im = cv2.imread(os.path.join(test_path, d))
    if im is None:
        # Non-image entry or unreadable file: cv2.imread returns None instead of raising.
        print("Could not read image:", d)
        continue
    # NOTE(review): assumes cfg.INPUT.FORMAT is detectron2's default "BGR",
    # which matches cv2.imread's channel order. Confirm.
    outputs = predictor(im)
    v = Visualizer(im[:, :, ::-1],
                   metadata=my_dataset_train_metadata,
                   scale=1,
                   )
    # [:1] selects the same single instance the old [0] did. When nothing is detected,
    # it yields an empty Instances (nothing drawn) instead of raising IndexError.
    # Drop the slice to draw every detection.
    v = v.draw_instance_predictions(outputs["instances"][:1].to("cpu"))
    img = v.get_image()[:, :, ::-1]
    cv2.namedWindow("asd", cv2.WINDOW_NORMAL)
    cv2.imshow('asd', img)
    cv2.waitKey(0)
cv2.destroyAllWindows()