Gitlab Community Edition Instance

Skip to content
Snippets Groups Projects
Commit d7fd0e1c authored by Dorothea Sommer's avatar Dorothea Sommer
Browse files

refactor dataloader

parent 90901ae7
Branches
No related tags found
No related merge requests found
......@@ -24,6 +24,34 @@ def fix_random_seed(seed: int, device=None) -> None:
torch.cuda.manual_seed_all(seed)
# For now, do not set torch.backends.cudnn.deterministic to True and cudnn.benchmark to False (it is faster without).
def create_dataloader(data_folder: str, verbose: bool = True, batch_size: int = 8,
                      validation_percentage: float = 0.2):
    """Build train and validation dataloaders over the point clouds in *data_folder*.

    Args:
        data_folder: Path to the dataset root, passed to ``PointCloudDataSet``.
        verbose: If True, print the sizes of the train/validation split.
        batch_size: Batch size used by both loaders. Previously this was an
            undefined free variable that only resolved through a global set in
            ``__main__``; it is now an explicit parameter (default 8, matching
            the value used there).
        validation_percentage: Fraction of the dataset held out for validation
            (default 0.2, matching the previous hard-coded value).

    Returns:
        Tuple ``(train_loader, val_loader)`` of ``torch.utils.data.DataLoader``.
    """
    # Each point cloud is down-sampled to 1024 points at random.
    # It would also be possible to sample the farthest points:
    # transformations = transforms.Compose([SamplePoints(1024, sample_method = "farthest_points")])
    transformations = transforms.Compose(
        [SamplePoints(1024, sample_method="random")])
    data = PointCloudDataSet(data_folder, train=True,
                             transform=transformations)

    # Random index split; np.random.shuffle uses the global NumPy RNG,
    # which is presumably seeded elsewhere via fix_random_seed — confirm.
    dataset_size = len(data)
    idx = list(range(dataset_size))
    split = int(np.floor(validation_percentage * dataset_size))
    np.random.shuffle(idx)
    train_idx, val_idx = idx[split:], idx[:split]
    if verbose:
        print(f"Training is done with {len(train_idx)} samples.")
        print(f"Validation is done with {len(val_idx)} samples.")

    # SubsetRandomSampler restricts each loader to its index subset and
    # reshuffles within that subset every epoch.
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)
    train_loader = DataLoader(data, batch_size=batch_size, sampler=train_sampler)
    val_loader = DataLoader(data, batch_size=batch_size, sampler=val_sampler)
    return train_loader, val_loader
if __name__ == "__main__":
print("Start training")
......@@ -38,38 +66,16 @@ if __name__ == "__main__":
print(f"Created path for models in {saved_models_path}")
learning_rate = 0.0005
batch_size = 8
### FIX SEEDS ###
fix_random_seed(26, device=device)
batch_size = 8
### DATA ####
data_folder = "/scratch/projects/forestcare/data/workshop/synthetic_trees_ten"
transformations = transforms.Compose(
[SamplePoints(1024, sample_method="random")])
# It would be also possible to sample the farthest points:
# transformations = transforms.Compose([SamplePoints(1024, sample_method = "farthest_points")])
data = PointCloudDataSet(data_folder, train=True,
transform=transformations)
validation_percentage = 0.2
dataset_size = len(data)
idx = list(range(dataset_size))
split = int(np.floor(validation_percentage * dataset_size))
np.random.shuffle(idx)
train_idx, val_idx = idx[split:], idx[:split]
print(f"Training is done with {len(train_idx)} samples.")
print(f"Validation is done with {len(val_idx)} samples.")
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
train_loader = DataLoader(data, batch_size=batch_size, sampler=train_sampler)
val_loader = DataLoader(data, batch_size=batch_size, sampler=val_sampler)
train_loader, val_loader = create_dataloader(data_folder)
### MODEL ####
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment